diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 62286a4271..9d2066f2e3 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -598,22 +598,22 @@ def save_worker_config(**data):
     os.environ["WORKER_CONFIG"] = json.dumps(data)
 
 def initialize(
-    model,
-    alias,
-    api_base,
-    api_version,
-    debug,
-    temperature,
-    max_tokens,
-    request_timeout,
-    max_budget,
-    telemetry,
-    drop_params,
-    add_function_to_prompt,
-    headers,
-    save,
-    config,
-    use_queue
+    model=None,
+    alias=None,
+    api_base=None,
+    api_version=None,
+    debug=False,
+    temperature=None,
+    max_tokens=None,
+    request_timeout=600,
+    max_budget=None,
+    telemetry=False,
+    drop_params=True,
+    add_function_to_prompt=True,
+    headers=None,
+    save=False,
+    use_queue=False,
+    config=None,
 ):
     global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth
     generate_feedback_box()
@@ -737,9 +737,17 @@ def litellm_completion(*args, **kwargs):
 async def startup_event():
     global prisma_client, master_key
     import json
-    worker_config = json.loads(os.getenv("WORKER_CONFIG"))
+
+    worker_config = litellm.get_secret("WORKER_CONFIG")
+    print(f"worker_config: {worker_config}")
     print_verbose(f"worker_config: {worker_config}")
-    initialize(**worker_config)
+    # check if it's a valid file path
+    if os.path.isfile(worker_config):
+        initialize(config=worker_config)
+    else:
+        # if not, assume it's a JSON string
+        worker_config = json.loads(os.getenv("WORKER_CONFIG"))
+        initialize(**worker_config)
     print_verbose(f"prisma client - {prisma_client}")
     if prisma_client:
         await prisma_client.connect()
diff --git a/litellm/tests/test_proxy_gunicorn.py b/litellm/tests/test_proxy_gunicorn.py
new file mode 100644
index 0000000000..4d96ac259f
--- /dev/null
+++ b/litellm/tests/test_proxy_gunicorn.py
@@ -0,0 +1,61 @@
+# #### What this tests ####
+# # Allow the user to easily run the local proxy server with Gunicorn
+## LOCAL TESTING ONLY
+# import sys, os, subprocess
+# import traceback
+# from dotenv import load_dotenv
+
+# load_dotenv()
+# import os, io
+
+# # this file is to test litellm/proxy
+
+# sys.path.insert(
+#     0, os.path.abspath("../..")
+# )  # Adds the parent directory to the system path
+# import pytest
+# import litellm
+
+# ### LOCAL Proxy Server INIT ###
+# from litellm.proxy.proxy_server import save_worker_config  # Replace with the actual module where your FastAPI router is defined
+# filepath = os.path.dirname(os.path.abspath(__file__))
+# config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
+# def get_openai_info():
+#     return {
+#         "api_key": os.getenv("AZURE_API_KEY"),
+#         "api_base": os.getenv("AZURE_API_BASE"),
+#     }
+
+# def run_server(host="0.0.0.0", port=8008, num_workers=None):
+#     if num_workers is None:
+#         # Set it to min(4, cpu_count())
+#         import multiprocessing
+#         num_workers = min(4, multiprocessing.cpu_count())
+
+#     ### LOAD KEYS ###
+
+#     # Load the Azure keys. For now get them from openai-usage
+#     azure_info = get_openai_info()
+#     print(f"Azure info: {azure_info}")
+#     os.environ["AZURE_API_KEY"] = azure_info['api_key']
+#     os.environ["AZURE_API_BASE"] = azure_info['api_base']
+#     os.environ["AZURE_API_VERSION"] = "2023-09-01-preview"
+
+#     ### SAVE CONFIG ###
+
+#     os.environ["WORKER_CONFIG"] = config_fp
+
+#     # In order for the app to behave well with signals, run it with gunicorn
+#     # The first argument must be the "name of the command run"
+#     cmd = f"gunicorn litellm.proxy.proxy_server:app --workers {num_workers} --worker-class uvicorn.workers.UvicornWorker --bind {host}:{port}"
+#     cmd = cmd.split()
+#     print(f"Running command: {cmd}")
+#     import sys
+#     sys.stdout.flush()
+#     sys.stderr.flush()
+
+#     # Make sure to propagate env variables
+#     subprocess.run(cmd)  # This line actually starts Gunicorn
+
+# if __name__ == "__main__":
+#     run_server()
\ No newline at end of file
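
Usage note: with the proxy_server.py change above, startup_event accepts WORKER_CONFIG either as a path to a proxy config file (the new behavior the Gunicorn test relies on) or as the original JSON-string form written by save_worker_config. A minimal sketch of the two forms, assuming a hypothetical config path and model settings:

    import json, os

    # File-path form: startup_event detects the path and calls initialize(config=...)
    os.environ["WORKER_CONFIG"] = "/path/to/proxy_config.yaml"  # hypothetical path

    # JSON-string form: falls through to json.loads and initialize(**worker_config), as before
    os.environ["WORKER_CONFIG"] = json.dumps({"model": "gpt-3.5-turbo", "debug": True})  # hypothetical settings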