mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
(Security fix) - Upgrade to fastapi==0.115.5
(#7447)
* fix: upgrade FastAPI * bump fastapi * update proxy startup tests * remove unused test file * update tests * bump fastapi
This commit is contained in:
parent
9d510d1907
commit
2c13b22705
8 changed files with 196 additions and 780 deletions
|
@ -102,6 +102,7 @@ def generate_feedback_box():
|
|||
|
||||
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import litellm
|
||||
from litellm import Router
|
||||
|
@ -364,6 +365,178 @@ _description = (
|
|||
else f"Proxy Server to call 100+ LLMs in the OpenAI format. {custom_swagger_message}\n\n{ui_message}"
|
||||
)
|
||||
|
||||
|
||||
def cleanup_router_config_variables():
    """Reset every config-derived proxy global back to its unset (None) state.

    Called on shutdown so that a subsequent startup re-reads all configuration
    from scratch instead of reusing stale state.
    """
    global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, use_background_health_checks, health_check_interval, prisma_client

    # Clear all of them in a single unpacking assignment.
    (
        master_key,
        user_config_file_path,
        otel_logging,
        user_custom_auth,
        user_custom_auth_path,
        user_custom_key_generate,
        user_custom_sso,
        use_background_health_checks,
        health_check_interval,
        prisma_client,
    ) = (None,) * 10
async def proxy_shutdown_event():
    """Release proxy resources on shutdown.

    Disconnects the Prisma DB client and the litellm cache, closes the JWT
    handler and the batch DB writer, flushes pending Langfuse logs (best
    effort), and finally resets all config-derived globals.
    """
    global prisma_client, master_key, user_custom_auth, user_custom_key_generate
    verbose_proxy_logger.info("Shutting down LiteLLM Proxy Server")

    # Disconnect the DB client first, if one was ever created.
    if prisma_client:
        verbose_proxy_logger.debug("Disconnecting from Prisma")
        await prisma_client.disconnect()

    if litellm.cache is not None:
        await litellm.cache.disconnect()

    await jwt_handler.close()

    if db_writer_client is not None:
        await db_writer_client.close()

    # Flush any remaining Langfuse logs; never let this block shutdown.
    if "langfuse" in litellm.success_callback:
        try:
            from litellm.utils import langFuseLogger

            if langFuseLogger is not None:
                langFuseLogger.Langfuse.flush()
        except Exception:
            pass

    ## RESET CUSTOM VARIABLES ##
    cleanup_router_config_variables()
@asynccontextmanager
async def proxy_startup_event(app: FastAPI):
    """FastAPI lifespan handler for the LiteLLM proxy.

    On startup: loads the master key, sets up the Prisma client, checks the
    license, loads proxy config (from CONFIG_FILE_PATH, WORKER_CONFIG file
    path, config bucket, dict, or JSON string), and wires up logging, JWT
    auth, health checks, and scheduled background jobs. After ``yield`` it
    delegates teardown to ``proxy_shutdown_event()``.
    """
    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, premium_user, _license_check
    import json

    init_verbose_loggers()

    ### LOAD MASTER KEY ###
    # Master key may be provided via the environment.
    master_key = get_secret("LITELLM_MASTER_KEY", None)  # type: ignore
    # Only build a Prisma client if none exists yet; DATABASE_URL may be unset.
    if prisma_client is None:
        _db_url: Optional[str] = get_secret("DATABASE_URL", None)  # type: ignore
        prisma_client = await ProxyStartupEvent._setup_prisma_client(
            database_url=_db_url,
            proxy_logging_obj=proxy_logging_obj,
            user_api_key_cache=user_api_key_cache,
        )

    ## CHECK PREMIUM USER
    verbose_proxy_logger.debug(
        "litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format(
            premium_user
        )
    )
    if premium_user is False:
        premium_user = _license_check.is_premium()

    ### LOAD CONFIG ###
    worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG")  # type: ignore
    env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
    verbose_proxy_logger.debug("worker_config: %s", worker_config)
    # CONFIG_FILE_PATH takes precedence over WORKER_CONFIG.
    if env_config_yaml is not None:
        if os.path.isfile(env_config_yaml) and proxy_config.is_yaml(
            config_file_path=env_config_yaml
        ):
            llm_router, llm_model_list, general_settings = (
                await proxy_config.load_config(
                    router=llm_router, config_file_path=env_config_yaml
                )
            )
    elif worker_config is not None:
        _is_local_yaml = (
            isinstance(worker_config, str)
            and os.path.isfile(worker_config)
            and proxy_config.is_yaml(config_file_path=worker_config)
        )
        if _is_local_yaml:
            llm_router, llm_model_list, general_settings = (
                await proxy_config.load_config(
                    router=llm_router, config_file_path=worker_config
                )
            )
        elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance(
            worker_config, str
        ):
            # Config lives in a bucket; load_config resolves it from the key.
            llm_router, llm_model_list, general_settings = (
                await proxy_config.load_config(
                    router=llm_router, config_file_path=worker_config
                )
            )
        elif isinstance(worker_config, dict):
            await initialize(**worker_config)
        else:
            # Fall back to treating WORKER_CONFIG as a JSON string.
            worker_config = json.loads(worker_config)
            if isinstance(worker_config, dict):
                await initialize(**worker_config)

    ProxyStartupEvent._initialize_startup_logging(
        llm_router=llm_router,
        proxy_logging_obj=proxy_logging_obj,
        redis_usage_cache=redis_usage_cache,
    )

    ## JWT AUTH ##
    ProxyStartupEvent._initialize_jwt_auth(
        general_settings=general_settings,
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
    )

    # Kick off the recurring background health-check coroutine, if enabled.
    if use_background_health_checks:
        asyncio.create_task(_run_background_health_check())

    if prompt_injection_detection_obj is not None:  # [TODO] - REFACTOR THIS
        prompt_injection_detection_obj.update_environment(router=llm_router)

    verbose_proxy_logger.debug("prisma_client: %s", prisma_client)
    if prisma_client is not None and master_key is not None:
        ProxyStartupEvent._add_master_key_hash_to_db(
            master_key=master_key,
            prisma_client=prisma_client,
            litellm_proxy_admin_name=litellm_proxy_admin_name,
            general_settings=general_settings,
        )

    if prisma_client is not None and litellm.max_budget > 0:
        ProxyStartupEvent._add_proxy_budget_to_db(
            litellm_proxy_budget_name=litellm_proxy_admin_name
        )

    ### START BATCH WRITING DB + CHECKING NEW MODELS###
    if prisma_client is not None:
        await ProxyStartupEvent.initialize_scheduled_background_jobs(
            general_settings=general_settings,
            prisma_client=prisma_client,
            proxy_budget_rescheduler_min_time=proxy_budget_rescheduler_min_time,
            proxy_budget_rescheduler_max_time=proxy_budget_rescheduler_max_time,
            proxy_batch_write_at=proxy_batch_write_at,
            proxy_logging_obj=proxy_logging_obj,
        )

    # Hand control to the running application; teardown follows on exit.
    yield
    await proxy_shutdown_event()
|
||||
app = FastAPI(
|
||||
docs_url=_get_docs_url(),
|
||||
redoc_url=_get_redoc_url(),
|
||||
|
@ -371,6 +544,7 @@ app = FastAPI(
|
|||
description=_description,
|
||||
version=version,
|
||||
root_path=server_root_path, # check if user passed root path, FastAPI defaults this value to ""
|
||||
lifespan=proxy_startup_event,
|
||||
)
|
||||
|
||||
|
||||
|
@ -2998,128 +3172,6 @@ class ProxyStartupEvent:
|
|||
return prisma_client
|
||||
|
||||
|
||||
@router.on_event("startup")
async def startup_event():
    """Legacy ``on_event`` startup hook (superseded by the lifespan handler).

    Loads the master key, Prisma client, license state, and proxy config,
    then initializes logging, JWT auth, health checks, and scheduled jobs.
    """
    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, premium_user, _license_check
    import json

    init_verbose_loggers()

    ### LOAD MASTER KEY ###
    # Master key may be provided via the environment.
    master_key = get_secret("LITELLM_MASTER_KEY", None)  # type: ignore
    # Only build a Prisma client if none exists yet.
    if prisma_client is None:
        _db_url: Optional[str] = get_secret("DATABASE_URL", None)  # type: ignore
        prisma_client = await ProxyStartupEvent._setup_prisma_client(
            database_url=_db_url,
            proxy_logging_obj=proxy_logging_obj,
            user_api_key_cache=user_api_key_cache,
        )

    ## CHECK PREMIUM USER
    verbose_proxy_logger.debug(
        "litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format(
            premium_user
        )
    )
    if premium_user is False:
        premium_user = _license_check.is_premium()

    ### LOAD CONFIG ###
    worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG")  # type: ignore
    env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
    verbose_proxy_logger.debug("worker_config: %s", worker_config)
    # CONFIG_FILE_PATH takes precedence over WORKER_CONFIG.
    if env_config_yaml is not None:
        if os.path.isfile(env_config_yaml) and proxy_config.is_yaml(
            config_file_path=env_config_yaml
        ):
            llm_router, llm_model_list, general_settings = (
                await proxy_config.load_config(
                    router=llm_router, config_file_path=env_config_yaml
                )
            )
    elif worker_config is not None:
        _is_local_yaml = (
            isinstance(worker_config, str)
            and os.path.isfile(worker_config)
            and proxy_config.is_yaml(config_file_path=worker_config)
        )
        if _is_local_yaml:
            llm_router, llm_model_list, general_settings = (
                await proxy_config.load_config(
                    router=llm_router, config_file_path=worker_config
                )
            )
        elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance(
            worker_config, str
        ):
            llm_router, llm_model_list, general_settings = (
                await proxy_config.load_config(
                    router=llm_router, config_file_path=worker_config
                )
            )
        elif isinstance(worker_config, dict):
            await initialize(**worker_config)
        else:
            # Fall back to treating WORKER_CONFIG as a JSON string.
            worker_config = json.loads(worker_config)
            if isinstance(worker_config, dict):
                await initialize(**worker_config)

    ProxyStartupEvent._initialize_startup_logging(
        llm_router=llm_router,
        proxy_logging_obj=proxy_logging_obj,
        redis_usage_cache=redis_usage_cache,
    )

    ## JWT AUTH ##
    ProxyStartupEvent._initialize_jwt_auth(
        general_settings=general_settings,
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
    )

    # Kick off the recurring background health-check coroutine, if enabled.
    if use_background_health_checks:
        asyncio.create_task(_run_background_health_check())

    if prompt_injection_detection_obj is not None:  # [TODO] - REFACTOR THIS
        prompt_injection_detection_obj.update_environment(router=llm_router)

    verbose_proxy_logger.debug("prisma_client: %s", prisma_client)
    if prisma_client is not None and master_key is not None:
        ProxyStartupEvent._add_master_key_hash_to_db(
            master_key=master_key,
            prisma_client=prisma_client,
            litellm_proxy_admin_name=litellm_proxy_admin_name,
            general_settings=general_settings,
        )

    if prisma_client is not None and litellm.max_budget > 0:
        ProxyStartupEvent._add_proxy_budget_to_db(
            litellm_proxy_budget_name=litellm_proxy_admin_name
        )

    ### START BATCH WRITING DB + CHECKING NEW MODELS###
    if prisma_client is not None:
        await ProxyStartupEvent.initialize_scheduled_background_jobs(
            general_settings=general_settings,
            prisma_client=prisma_client,
            proxy_budget_rescheduler_min_time=proxy_budget_rescheduler_min_time,
            proxy_budget_rescheduler_max_time=proxy_budget_rescheduler_max_time,
            proxy_batch_write_at=proxy_batch_write_at,
            proxy_logging_obj=proxy_logging_obj,
        )
|
||||
#### API ENDPOINTS ####
|
||||
@router.get(
|
||||
"/v1/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"]
|
||||
|
@ -8556,54 +8608,6 @@ async def get_routes():
|
|||
# return {"token": token}
|
||||
|
||||
|
||||
@router.on_event("shutdown")
async def shutdown_event():
    """Legacy ``on_event`` shutdown hook: release proxy resources.

    Disconnects DB/cache clients, closes the JWT handler and batch writer,
    flushes Langfuse logs (best effort), and resets config globals.
    """
    global prisma_client, master_key, user_custom_auth, user_custom_key_generate
    verbose_proxy_logger.info("Shutting down LiteLLM Proxy Server")

    # Disconnect the DB client first, if one was ever created.
    if prisma_client:
        verbose_proxy_logger.debug("Disconnecting from Prisma")
        await prisma_client.disconnect()

    if litellm.cache is not None:
        await litellm.cache.disconnect()

    await jwt_handler.close()

    if db_writer_client is not None:
        await db_writer_client.close()

    # Flush any remaining Langfuse logs; never let this block shutdown.
    if "langfuse" in litellm.success_callback:
        try:
            from litellm.utils import langFuseLogger

            if langFuseLogger is not None:
                langFuseLogger.Langfuse.flush()
        except Exception:
            pass

    ## RESET CUSTOM VARIABLES ##
    cleanup_router_config_variables()
||||
|
||||
def cleanup_router_config_variables():
    """Reset every config-derived proxy global back to its unset (None) state.

    Called on shutdown so that a subsequent startup re-reads all configuration
    from scratch instead of reusing stale state.
    """
    global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, use_background_health_checks, health_check_interval, prisma_client

    # Clear all of them in a single unpacking assignment.
    (
        master_key,
        user_config_file_path,
        otel_logging,
        user_custom_auth,
        user_custom_auth_path,
        user_custom_key_generate,
        user_custom_sso,
        use_background_health_checks,
        health_check_interval,
        prisma_client,
    ) = (None,) * 10
||||
|
||||
# Mount the proxy's route groups onto the app (registration order preserved).
for _sub_router in (router, batches_router, rerank_router):
    app.include_router(_sub_router)
|
|
Loading…
Add table
Add a link
Reference in a new issue