fix(proxy_server.py): allow worker config to just be the config filepath

Krrish Dholakia 2023-12-06 14:02:44 -08:00
parent 368934d160
commit 346551da29
2 changed files with 87 additions and 18 deletions


@@ -598,22 +598,22 @@ def save_worker_config(**data):
     os.environ["WORKER_CONFIG"] = json.dumps(data)
 def initialize(
-    model,
-    alias,
-    api_base,
-    api_version,
-    debug,
-    temperature,
-    max_tokens,
-    request_timeout,
-    max_budget,
-    telemetry,
-    drop_params,
-    add_function_to_prompt,
-    headers,
-    save,
-    config,
-    use_queue
+    model=None,
+    alias=None,
+    api_base=None,
+    api_version=None,
+    debug=False,
+    temperature=None,
+    max_tokens=None,
+    request_timeout=600,
+    max_budget=None,
+    telemetry=False,
+    drop_params=True,
+    add_function_to_prompt=True,
+    headers=None,
+    save=False,
+    use_queue=False,
+    config=None,
 ):
     global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth
     generate_feedback_box()
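
Because every parameter now carries a default, initialize() can be driven by a single keyword argument such as config. A minimal sketch of that pattern, using a hypothetical initialize_sketch() stand-in (and an illustrative config path) rather than the real litellm function:

```python
# Sketch only: initialize_sketch() is a hypothetical stand-in for
# proxy_server.initialize(); the defaults mirror the ones added in this commit.
def initialize_sketch(
    model=None,
    debug=False,
    request_timeout=600,
    telemetry=False,
    config=None,
):
    # With keyword defaults, callers only pass what they actually have.
    return {
        "model": model,
        "debug": debug,
        "request_timeout": request_timeout,
        "telemetry": telemetry,
        "config": config,
    }


# Filepath-only call, now possible because nothing else is required:
print(initialize_sketch(config="/path/to/proxy_config.yaml"))

# The older explicit-kwargs style still works unchanged:
print(initialize_sketch(model="gpt-3.5-turbo", debug=True))
```
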
@@ -737,9 +737,17 @@ def litellm_completion(*args, **kwargs):
 async def startup_event():
     global prisma_client, master_key
     import json
-    worker_config = json.loads(os.getenv("WORKER_CONFIG"))
-    print(f"worker_config: {worker_config}")
-    initialize(**worker_config)
+    worker_config = litellm.get_secret("WORKER_CONFIG")
+    print_verbose(f"worker_config: {worker_config}")
+    # check if it's a valid file path
+    if os.path.isfile(worker_config):
+        initialize(config=worker_config)
+    else:
+        # if not, assume it's a json string
+        worker_config = json.loads(os.getenv("WORKER_CONFIG"))
+        initialize(**worker_config)
     print_verbose(f"prisma client - {prisma_client}")
     if prisma_client:
         await prisma_client.connect()
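
Taken together, the two hunks mean WORKER_CONFIG no longer has to be a JSON blob: the startup hook now checks whether the value is an existing file and, if so, passes it straight through as config. A minimal sketch of that dispatch, with a hypothetical resolve_worker_config() helper standing in for the real initialize() call:

```python
import json
import os
import tempfile


def resolve_worker_config(env_value: str) -> dict:
    """Hypothetical helper: turn WORKER_CONFIG into kwargs for initialize()."""
    if os.path.isfile(env_value):
        # filepath form: hand the path through as the config kwarg
        return {"config": env_value}
    # otherwise treat the value as a JSON blob of keyword arguments
    return json.loads(env_value)


# JSON form (the pre-existing behaviour):
print(resolve_worker_config(json.dumps({"model": "gpt-3.5-turbo", "debug": True})))

# Filepath form (the new behaviour), shown with a throwaway temp file:
with tempfile.NamedTemporaryFile(suffix=".yaml") as tmp:
    print(resolve_worker_config(tmp.name))
```

Branching on os.path.isfile keeps the JSON path intact, so deployments that already export a JSON string keep working while new ones can export just the config filepath.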