(Security fix) - Upgrade to fastapi==0.115.5 (#7447)

* fix upgrade fast api

* bump fastapi

* update a proxy startup tests

* remove unused test file

* update tests

* bump fast api
This commit is contained in:
Ishaan Jaff 2024-12-28 17:08:19 -08:00 committed by GitHub
parent 9d510d1907
commit 2c13b22705
8 changed files with 196 additions and 780 deletions

View file

@ -102,6 +102,7 @@ def generate_feedback_box():
from collections import defaultdict
from contextlib import asynccontextmanager
import litellm
from litellm import Router
@ -364,6 +365,178 @@ _description = (
else f"Proxy Server to call 100+ LLMs in the OpenAI format. {custom_swagger_message}\n\n{ui_message}"
)
def cleanup_router_config_variables():
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, use_background_health_checks, health_check_interval, prisma_client
# Set all variables to None
master_key = None
user_config_file_path = None
otel_logging = None
user_custom_auth = None
user_custom_auth_path = None
user_custom_key_generate = None
user_custom_sso = None
use_background_health_checks = None
health_check_interval = None
prisma_client = None
async def proxy_shutdown_event():
global prisma_client, master_key, user_custom_auth, user_custom_key_generate
verbose_proxy_logger.info("Shutting down LiteLLM Proxy Server")
if prisma_client:
verbose_proxy_logger.debug("Disconnecting from Prisma")
await prisma_client.disconnect()
if litellm.cache is not None:
await litellm.cache.disconnect()
await jwt_handler.close()
if db_writer_client is not None:
await db_writer_client.close()
# flush remaining langfuse logs
if "langfuse" in litellm.success_callback:
try:
# flush langfuse logs on shutdow
from litellm.utils import langFuseLogger
if langFuseLogger is not None:
langFuseLogger.Langfuse.flush()
except Exception:
# [DO NOT BLOCK shutdown events for this]
pass
## RESET CUSTOM VARIABLES ##
cleanup_router_config_variables()
@asynccontextmanager
async def proxy_startup_event(app: FastAPI):
global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, premium_user, _license_check
import json
init_verbose_loggers()
### LOAD MASTER KEY ###
# check if master key set in environment - load from there
master_key = get_secret("LITELLM_MASTER_KEY", None) # type: ignore
# check if DATABASE_URL in environment - load from there
if prisma_client is None:
_db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore
prisma_client = await ProxyStartupEvent._setup_prisma_client(
database_url=_db_url,
proxy_logging_obj=proxy_logging_obj,
user_api_key_cache=user_api_key_cache,
)
## CHECK PREMIUM USER
verbose_proxy_logger.debug(
"litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format(
premium_user
)
)
if premium_user is False:
premium_user = _license_check.is_premium()
### LOAD CONFIG ###
worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore
env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
verbose_proxy_logger.debug("worker_config: %s", worker_config)
# check if it's a valid file path
if env_config_yaml is not None:
if os.path.isfile(env_config_yaml) and proxy_config.is_yaml(
config_file_path=env_config_yaml
):
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=env_config_yaml
)
elif worker_config is not None:
if (
isinstance(worker_config, str)
and os.path.isfile(worker_config)
and proxy_config.is_yaml(config_file_path=worker_config)
):
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config
)
elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance(
worker_config, str
):
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config
)
elif isinstance(worker_config, dict):
await initialize(**worker_config)
else:
# if not, assume it's a json string
worker_config = json.loads(worker_config)
if isinstance(worker_config, dict):
await initialize(**worker_config)
ProxyStartupEvent._initialize_startup_logging(
llm_router=llm_router,
proxy_logging_obj=proxy_logging_obj,
redis_usage_cache=redis_usage_cache,
)
## JWT AUTH ##
ProxyStartupEvent._initialize_jwt_auth(
general_settings=general_settings,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
if use_background_health_checks:
asyncio.create_task(
_run_background_health_check()
) # start the background health check coroutine.
if prompt_injection_detection_obj is not None: # [TODO] - REFACTOR THIS
prompt_injection_detection_obj.update_environment(router=llm_router)
verbose_proxy_logger.debug("prisma_client: %s", prisma_client)
if prisma_client is not None and master_key is not None:
ProxyStartupEvent._add_master_key_hash_to_db(
master_key=master_key,
prisma_client=prisma_client,
litellm_proxy_admin_name=litellm_proxy_admin_name,
general_settings=general_settings,
)
if prisma_client is not None and litellm.max_budget > 0:
ProxyStartupEvent._add_proxy_budget_to_db(
litellm_proxy_budget_name=litellm_proxy_admin_name
)
### START BATCH WRITING DB + CHECKING NEW MODELS###
if prisma_client is not None:
await ProxyStartupEvent.initialize_scheduled_background_jobs(
general_settings=general_settings,
prisma_client=prisma_client,
proxy_budget_rescheduler_min_time=proxy_budget_rescheduler_min_time,
proxy_budget_rescheduler_max_time=proxy_budget_rescheduler_max_time,
proxy_batch_write_at=proxy_batch_write_at,
proxy_logging_obj=proxy_logging_obj,
)
yield
await proxy_shutdown_event()
app = FastAPI(
docs_url=_get_docs_url(),
redoc_url=_get_redoc_url(),
@ -371,6 +544,7 @@ app = FastAPI(
description=_description,
version=version,
root_path=server_root_path, # check if user passed root path, FastAPI defaults this value to ""
lifespan=proxy_startup_event,
)
@ -2998,128 +3172,6 @@ class ProxyStartupEvent:
return prisma_client
@router.on_event("startup")
async def startup_event():
global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, premium_user, _license_check
import json
init_verbose_loggers()
### LOAD MASTER KEY ###
# check if master key set in environment - load from there
master_key = get_secret("LITELLM_MASTER_KEY", None) # type: ignore
# check if DATABASE_URL in environment - load from there
if prisma_client is None:
_db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore
prisma_client = await ProxyStartupEvent._setup_prisma_client(
database_url=_db_url,
proxy_logging_obj=proxy_logging_obj,
user_api_key_cache=user_api_key_cache,
)
## CHECK PREMIUM USER
verbose_proxy_logger.debug(
"litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format(
premium_user
)
)
if premium_user is False:
premium_user = _license_check.is_premium()
### LOAD CONFIG ###
worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore
env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
verbose_proxy_logger.debug("worker_config: %s", worker_config)
# check if it's a valid file path
if env_config_yaml is not None:
if os.path.isfile(env_config_yaml) and proxy_config.is_yaml(
config_file_path=env_config_yaml
):
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=env_config_yaml
)
elif worker_config is not None:
if (
isinstance(worker_config, str)
and os.path.isfile(worker_config)
and proxy_config.is_yaml(config_file_path=worker_config)
):
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config
)
elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance(
worker_config, str
):
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config
)
elif isinstance(worker_config, dict):
await initialize(**worker_config)
else:
# if not, assume it's a json string
worker_config = json.loads(worker_config)
if isinstance(worker_config, dict):
await initialize(**worker_config)
ProxyStartupEvent._initialize_startup_logging(
llm_router=llm_router,
proxy_logging_obj=proxy_logging_obj,
redis_usage_cache=redis_usage_cache,
)
## JWT AUTH ##
ProxyStartupEvent._initialize_jwt_auth(
general_settings=general_settings,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
if use_background_health_checks:
asyncio.create_task(
_run_background_health_check()
) # start the background health check coroutine.
if prompt_injection_detection_obj is not None: # [TODO] - REFACTOR THIS
prompt_injection_detection_obj.update_environment(router=llm_router)
verbose_proxy_logger.debug("prisma_client: %s", prisma_client)
if prisma_client is not None and master_key is not None:
ProxyStartupEvent._add_master_key_hash_to_db(
master_key=master_key,
prisma_client=prisma_client,
litellm_proxy_admin_name=litellm_proxy_admin_name,
general_settings=general_settings,
)
if prisma_client is not None and litellm.max_budget > 0:
ProxyStartupEvent._add_proxy_budget_to_db(
litellm_proxy_budget_name=litellm_proxy_admin_name
)
### START BATCH WRITING DB + CHECKING NEW MODELS###
if prisma_client is not None:
await ProxyStartupEvent.initialize_scheduled_background_jobs(
general_settings=general_settings,
prisma_client=prisma_client,
proxy_budget_rescheduler_min_time=proxy_budget_rescheduler_min_time,
proxy_budget_rescheduler_max_time=proxy_budget_rescheduler_max_time,
proxy_batch_write_at=proxy_batch_write_at,
proxy_logging_obj=proxy_logging_obj,
)
#### API ENDPOINTS ####
@router.get(
"/v1/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"]
@ -8556,54 +8608,6 @@ async def get_routes():
# return {"token": token}
@router.on_event("shutdown")
async def shutdown_event():
global prisma_client, master_key, user_custom_auth, user_custom_key_generate
verbose_proxy_logger.info("Shutting down LiteLLM Proxy Server")
if prisma_client:
verbose_proxy_logger.debug("Disconnecting from Prisma")
await prisma_client.disconnect()
if litellm.cache is not None:
await litellm.cache.disconnect()
await jwt_handler.close()
if db_writer_client is not None:
await db_writer_client.close()
# flush remaining langfuse logs
if "langfuse" in litellm.success_callback:
try:
# flush langfuse logs on shutdow
from litellm.utils import langFuseLogger
if langFuseLogger is not None:
langFuseLogger.Langfuse.flush()
except Exception:
# [DO NOT BLOCK shutdown events for this]
pass
## RESET CUSTOM VARIABLES ##
cleanup_router_config_variables()
def cleanup_router_config_variables():
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, use_background_health_checks, health_check_interval, prisma_client
# Set all variables to None
master_key = None
user_config_file_path = None
otel_logging = None
user_custom_auth = None
user_custom_auth_path = None
user_custom_key_generate = None
user_custom_sso = None
use_background_health_checks = None
health_check_interval = None
prisma_client = None
app.include_router(router)
app.include_router(batches_router)
app.include_router(rerank_router)