fix(proxy_server.py): allow worker config to just be the config filepath

This commit is contained in:
Krrish Dholakia 2023-12-06 14:02:44 -08:00
parent 6b8d1a20f5
commit 0dff2ebf55
2 changed files with 87 additions and 18 deletions

View file

@ -598,22 +598,22 @@ def save_worker_config(**data):
os.environ["WORKER_CONFIG"] = json.dumps(data) os.environ["WORKER_CONFIG"] = json.dumps(data)
def initialize( def initialize(
model, model=None,
alias, alias=None,
api_base, api_base=None,
api_version, api_version=None,
debug, debug=False,
temperature, temperature=None,
max_tokens, max_tokens=None,
request_timeout, request_timeout=600,
max_budget, max_budget=None,
telemetry, telemetry=False,
drop_params, drop_params=True,
add_function_to_prompt, add_function_to_prompt=True,
headers, headers=None,
save, save=False,
config, use_queue=False,
use_queue config=None,
): ):
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth
generate_feedback_box() generate_feedback_box()
@ -737,9 +737,17 @@ def litellm_completion(*args, **kwargs):
async def startup_event(): async def startup_event():
global prisma_client, master_key global prisma_client, master_key
import json import json
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
worker_config = litellm.get_secret("WORKER_CONFIG")
print(f"worker_config: {worker_config}")
print_verbose(f"worker_config: {worker_config}") print_verbose(f"worker_config: {worker_config}")
initialize(**worker_config) # check if it's a valid file path
if os.path.isfile(worker_config):
initialize(config=worker_config)
else:
# if not, assume it's a json string
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
initialize(**worker_config)
print_verbose(f"prisma client - {prisma_client}") print_verbose(f"prisma client - {prisma_client}")
if prisma_client: if prisma_client:
await prisma_client.connect() await prisma_client.connect()

View file

@ -0,0 +1,61 @@
# #### What this tests ####
# # Allow the user to easily run the local proxy server with Gunicorn
## LOCAL TESTING ONLY
# import sys, os, subprocess
# import traceback
# from dotenv import load_dotenv
# load_dotenv()
# import os, io
# # this file is to test litellm/proxy
# sys.path.insert(
# 0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path
# import pytest
# import litellm
# ### LOCAL Proxy Server INIT ###
# from litellm.proxy.proxy_server import save_worker_config # Replace with the actual module where your FastAPI router is defined
# filepath = os.path.dirname(os.path.abspath(__file__))
# config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
# def get_openai_info():
# return {
# "api_key": os.getenv("AZURE_API_KEY"),
# "api_base": os.getenv("AZURE_API_BASE"),
# }
# def run_server(host="0.0.0.0",port=8008,num_workers=None):
# if num_workers is None:
# # Set it to min(8,cpu_count())
# import multiprocessing
# num_workers = min(4,multiprocessing.cpu_count())
# ### LOAD KEYS ###
# # Load the Azure keys. For now get them from openai-usage
# azure_info = get_openai_info()
# print(f"Azure info:{azure_info}")
# os.environ["AZURE_API_KEY"] = azure_info['api_key']
# os.environ["AZURE_API_BASE"] = azure_info['api_base']
# os.environ["AZURE_API_VERSION"] = "2023-09-01-preview"
# ### SAVE CONFIG ###
# os.environ["WORKER_CONFIG"] = config_fp
# # In order for the app to behave well with signals, run it with gunicorn
# # The first argument must be the "name of the command run"
# cmd = f"gunicorn litellm.proxy.proxy_server:app --workers {num_workers} --worker-class uvicorn.workers.UvicornWorker --bind {host}:{port}"
# cmd = cmd.split()
# print(f"Running command: {cmd}")
# import sys
# sys.stdout.flush()
# sys.stderr.flush()
# # Make sure to propagate env variables
# subprocess.run(cmd) # This line actually starts Gunicorn
# if __name__ == "__main__":
# run_server()