Merge pull request #1367 from BerriAI/litellm_proxy_startup

fix(proxy_server.py): add support for passing in config file via worker_config directly + testing
This commit is contained in:
Krish Dholakia 2024-01-08 19:46:48 +05:30 committed by GitHub
commit e949a2ada3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 1 deletions

View file

@ -507,6 +507,14 @@ class ProxyConfig:
def __init__(self) -> None:
pass
def is_yaml(self, config_file_path: str) -> bool:
if not os.path.isfile(config_file_path):
return False
_, file_extension = os.path.splitext(config_file_path)
return file_extension.lower() == '.yaml' or file_extension.lower() == '.yml'
async def get_config(self, config_file_path: Optional[str] = None) -> dict:
global prisma_client, user_config_file_path
@ -1156,6 +1164,13 @@ async def startup_event():
verbose_proxy_logger.debug(f"worker_config: {worker_config}")
# check if it's a valid file path
if os.path.isfile(worker_config):
if proxy_config.is_yaml(config_file_path=worker_config):
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(router=llm_router, config_file_path=worker_config)
else:
await initialize(**worker_config)
else:
# if not, assume it's a json string

View file

@ -0,0 +1,50 @@
# What this tests
## This tests the proxy server startup
import sys, os, json
import traceback
from dotenv import load_dotenv
load_dotenv()
import os, io
# this file is to test litellm/proxy
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest, logging, asyncio
import litellm
from litellm.proxy.proxy_server import (
router,
save_worker_config,
initialize,
startup_event,
llm_model_list
)
def test_proxy_gunicorn_startup_direct_config():
"""
gunicorn startup requires the config to be passed in via environment variables
We support saving either the config or the dict as an environment variable.
Test both approaches
"""
filepath = os.path.dirname(os.path.abspath(__file__))
# test with worker_config = config yaml
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
os.environ["WORKER_CONFIG"] = config_fp
asyncio.run(startup_event())
def test_proxy_gunicorn_startup_config_dict():
filepath = os.path.dirname(os.path.abspath(__file__))
# test with worker_config = config yaml
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
# test with worker_config = dict
worker_config = {"config": config_fp}
os.environ["WORKER_CONFIG"] = json.dumps(worker_config)
asyncio.run(startup_event())
# test_proxy_gunicorn_startup()