diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index bf19c364e..54563ab61 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -346,12 +346,14 @@ class PrometheusLogger(CustomLogger): standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( "standard_logging_object" ) + if standard_logging_payload is None or not isinstance( standard_logging_payload, dict ): raise ValueError( f"standard_logging_object is required, got={standard_logging_payload}" ) + model = kwargs.get("model", "") litellm_params = kwargs.get("litellm_params", {}) or {} _metadata = litellm_params.get("metadata", {}) diff --git a/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py b/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py index 95749037d..76077ad46 100644 --- a/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py +++ b/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py @@ -260,6 +260,7 @@ def convert_to_model_response_object( # noqa: PLR0915 ] = None, # used for supporting 'json_schema' on older models ): received_args = locals() + additional_headers = get_response_headers(_response_headers) if hidden_params is None: @@ -448,11 +449,13 @@ def convert_to_model_response_object( # noqa: PLR0915 ): if response_object is None: raise Exception("Error in response object format") + return LiteLLMResponseObjectHandler.convert_to_image_response( response_object=response_object, model_response_object=model_response_object, hidden_params=hidden_params, ) + elif response_type == "audio_transcription" and ( model_response_object is None or isinstance(model_response_object, TranscriptionResponse) diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 2cdf35b70..ee56b3366 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -48,4 +48,3 @@ router_settings: redis_host: os.environ/REDIS_HOST redis_port: os.environ/REDIS_PORT redis_password: os.environ/REDIS_PASSWORD - diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 163d2ff25..bcdd4e86c 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -638,18 +638,6 @@ def _resolve_pydantic_type(typ) -> List: return typs -def prisma_setup(database_url: Optional[str]): - global prisma_client, proxy_logging_obj, user_api_key_cache - - if database_url is not None: - try: - prisma_client = PrismaClient( - database_url=database_url, proxy_logging_obj=proxy_logging_obj - ) - except Exception as e: - raise e - - def load_from_azure_key_vault(use_azure_key_vault: bool = False): if use_azure_key_vault is False: return @@ -1543,7 +1531,6 @@ class ProxyConfig: ## INIT PROXY REDIS USAGE CLIENT ## redis_usage_cache = litellm.cache.cache - async def get_config(self, config_file_path: Optional[str] = None) -> dict: """ Load config file @@ -2796,137 +2783,55 @@ def giveup(e): return result -@router.on_event("startup") -async def startup_event(): # noqa: PLR0915 - global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, premium_user, _license_check - import json +class ProxyStartupEvent: + @classmethod + def _initialize_startup_logging( + cls, + llm_router: Optional[litellm.Router], + proxy_logging_obj: 
ProxyLogging, + redis_usage_cache: Optional[RedisCache], + ): + """Initialize logging and alerting on startup""" + ## COST TRACKING ## + cost_tracking() - init_verbose_loggers() + ## Error Tracking ## + error_tracking() - ### LOAD MASTER KEY ### - # check if master key set in environment - load from there - master_key = get_secret("LITELLM_MASTER_KEY", None) # type: ignore - # check if DATABASE_URL in environment - load from there - if prisma_client is None: - _db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore - prisma_setup(database_url=_db_url) + proxy_logging_obj.startup_event( + llm_router=llm_router, redis_usage_cache=redis_usage_cache + ) - ### LOAD CONFIG ### - worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore - env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH") - verbose_proxy_logger.debug("worker_config: %s", worker_config) - # check if it's a valid file path - if env_config_yaml is not None: - if os.path.isfile(env_config_yaml) and proxy_config.is_yaml( - config_file_path=env_config_yaml - ): - ( - llm_router, - llm_model_list, - general_settings, - ) = await proxy_config.load_config( - router=llm_router, config_file_path=env_config_yaml - ) - elif worker_config is not None: - if ( - isinstance(worker_config, str) - and os.path.isfile(worker_config) - and proxy_config.is_yaml(config_file_path=worker_config) - ): - ( - llm_router, - llm_model_list, - general_settings, - ) = await proxy_config.load_config( - router=llm_router, config_file_path=worker_config - ) - elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance( - worker_config, str - ): - ( - llm_router, - llm_model_list, - general_settings, - ) = await proxy_config.load_config( - router=llm_router, config_file_path=worker_config - ) - elif isinstance(worker_config, dict): - await initialize(**worker_config) + @classmethod + def _initialize_jwt_auth( + cls, + general_settings: dict, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + ): + """Initialize JWT auth on startup""" + if general_settings.get("litellm_jwtauth", None) is not None: + for k, v in general_settings["litellm_jwtauth"].items(): + if isinstance(v, str) and v.startswith("os.environ/"): + general_settings["litellm_jwtauth"][k] = get_secret(v) + litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"]) else: - # if not, assume it's a json string - worker_config = json.loads(worker_config) - if isinstance(worker_config, dict): - await initialize(**worker_config) - - ## CHECK PREMIUM USER - verbose_proxy_logger.debug( - "litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format( - premium_user + litellm_jwtauth = LiteLLM_JWTAuth() + jwt_handler.update_environment( + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + litellm_jwtauth=litellm_jwtauth, ) - ) - if premium_user is False: - premium_user = _license_check.is_premium() - verbose_proxy_logger.debug( - "litellm.proxy.proxy_server.py::startup() - PREMIUM USER value - {}".format( - premium_user - ) - ) - - ## COST TRACKING ## - cost_tracking() - - ## Error Tracking ## - error_tracking() - - ## UPDATE SLACK ALERTING ## - proxy_logging_obj.slack_alerting_instance.update_values(llm_router=llm_router) - - db_writer_client = HTTPHandler() - - ## UPDATE INTERNAL USAGE CACHE ## - proxy_logging_obj.update_values( - redis_cache=redis_usage_cache - ) # used by parallel request limiter for rate limiting keys across instances - - 
proxy_logging_obj._init_litellm_callbacks( - llm_router=llm_router - ) # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made - - if "daily_reports" in proxy_logging_obj.slack_alerting_instance.alert_types: - asyncio.create_task( - proxy_logging_obj.slack_alerting_instance._run_scheduled_daily_report( - llm_router=llm_router - ) - ) # RUN DAILY REPORT (if scheduled) - - ## JWT AUTH ## - if general_settings.get("litellm_jwtauth", None) is not None: - for k, v in general_settings["litellm_jwtauth"].items(): - if isinstance(v, str) and v.startswith("os.environ/"): - general_settings["litellm_jwtauth"][k] = get_secret(v) - litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"]) - else: - litellm_jwtauth = LiteLLM_JWTAuth() - jwt_handler.update_environment( - prisma_client=prisma_client, - user_api_key_cache=user_api_key_cache, - litellm_jwtauth=litellm_jwtauth, - ) - - if use_background_health_checks: - asyncio.create_task( - _run_background_health_check() - ) # start the background health check coroutine. - - if prompt_injection_detection_obj is not None: - prompt_injection_detection_obj.update_environment(router=llm_router) - - verbose_proxy_logger.debug("prisma_client: %s", prisma_client) - if prisma_client is not None: - await prisma_client.connect() - - if prisma_client is not None and master_key is not None: + @classmethod + def _add_master_key_hash_to_db( + cls, + master_key: str, + prisma_client: PrismaClient, + litellm_proxy_admin_name: str, + general_settings: dict, + ): + """Adds master key hash to db for cost tracking""" if os.getenv("PROXY_ADMIN_ID", None) is not None: litellm_proxy_admin_name = os.getenv( "PROXY_ADMIN_ID", litellm_proxy_admin_name @@ -2951,7 +2856,9 @@ async def startup_event(): # noqa: PLR0915 ) asyncio.create_task(task_1) - if prisma_client is not None and litellm.max_budget > 0: + @classmethod + def _add_proxy_budget_to_db(cls, litellm_proxy_budget_name: str): + """Adds a global proxy budget to db""" if litellm.budget_duration is None: raise Exception( "budget_duration not set on Proxy. budget_duration is required to use max_budget." @@ -2977,8 +2884,18 @@ async def startup_event(): # noqa: PLR0915 ) ) - ### START BATCH WRITING DB + CHECKING NEW MODELS### - if prisma_client is not None: + @classmethod + async def initialize_scheduled_background_jobs( + cls, + general_settings: dict, + prisma_client: PrismaClient, + proxy_budget_rescheduler_min_time: int, + proxy_budget_rescheduler_max_time: int, + proxy_batch_write_at: int, + proxy_logging_obj: ProxyLogging, + store_model_in_db: bool, + ): + """Initializes scheduled background jobs""" scheduler = AsyncIOScheduler() interval = random.randint( proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time @@ -3067,6 +2984,165 @@ async def startup_event(): # noqa: PLR0915 scheduler.start() + @classmethod + def _setup_prisma_client( + cls, + database_url: Optional[str], + proxy_logging_obj: ProxyLogging, + user_api_key_cache: DualCache, + ) -> Optional[PrismaClient]: + """ + - Sets up prisma client + - Adds necessary views to proxy + """ + prisma_client: Optional[PrismaClient] = None + if database_url is not None: + try: + prisma_client = PrismaClient( + database_url=database_url, proxy_logging_obj=proxy_logging_obj + ) + except Exception as e: + raise e + + ## Add necessary views to proxy ## + asyncio.create_task( + prisma_client.check_view_exists() + ) # check if all necessary views exist. 
Don't block execution + + return prisma_client + + +@router.on_event("startup") +async def startup_event(): + global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, premium_user, _license_check + import json + + init_verbose_loggers() + + ### LOAD MASTER KEY ### + # check if master key set in environment - load from there + master_key = get_secret("LITELLM_MASTER_KEY", None) # type: ignore + # check if DATABASE_URL in environment - load from there + if prisma_client is None: + _db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore + prisma_client = ProxyStartupEvent._setup_prisma_client( + database_url=_db_url, + proxy_logging_obj=proxy_logging_obj, + user_api_key_cache=user_api_key_cache, + ) + + ### LOAD CONFIG ### + worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore + env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH") + verbose_proxy_logger.debug("worker_config: %s", worker_config) + # check if it's a valid file path + if env_config_yaml is not None: + if os.path.isfile(env_config_yaml) and proxy_config.is_yaml( + config_file_path=env_config_yaml + ): + ( + llm_router, + llm_model_list, + general_settings, + ) = await proxy_config.load_config( + router=llm_router, config_file_path=env_config_yaml + ) + elif worker_config is not None: + if ( + isinstance(worker_config, str) + and os.path.isfile(worker_config) + and proxy_config.is_yaml(config_file_path=worker_config) + ): + ( + llm_router, + llm_model_list, + general_settings, + ) = await proxy_config.load_config( + router=llm_router, config_file_path=worker_config + ) + elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance( + worker_config, str + ): + ( + llm_router, + llm_model_list, + general_settings, + ) = await proxy_config.load_config( + router=llm_router, config_file_path=worker_config + ) + elif isinstance(worker_config, dict): + await initialize(**worker_config) + else: + # if not, assume it's a json string + worker_config = json.loads(worker_config) + if isinstance(worker_config, dict): + await initialize(**worker_config) + + ## CHECK PREMIUM USER + verbose_proxy_logger.debug( + "litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format( + premium_user + ) + ) + if premium_user is False: + premium_user = _license_check.is_premium() + + verbose_proxy_logger.debug( + "litellm.proxy.proxy_server.py::startup() - PREMIUM USER value - {}".format( + premium_user + ) + ) + + ProxyStartupEvent._initialize_startup_logging( + llm_router=llm_router, + proxy_logging_obj=proxy_logging_obj, + redis_usage_cache=redis_usage_cache, + ) + + ## JWT AUTH ## + ProxyStartupEvent._initialize_jwt_auth( + general_settings=general_settings, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + ) + + if use_background_health_checks: + asyncio.create_task( + _run_background_health_check() + ) # start the background health check coroutine. 
+ + if prompt_injection_detection_obj is not None: # [TODO] - REFACTOR THIS + prompt_injection_detection_obj.update_environment(router=llm_router) + + verbose_proxy_logger.debug("prisma_client: %s", prisma_client) + if prisma_client is not None: + await prisma_client.connect() + + if prisma_client is not None and master_key is not None: + ProxyStartupEvent._add_master_key_hash_to_db( + master_key=master_key, + prisma_client=prisma_client, + litellm_proxy_admin_name=litellm_proxy_admin_name, + general_settings=general_settings, + ) + + if prisma_client is not None and litellm.max_budget > 0: + ProxyStartupEvent._add_proxy_budget_to_db( + litellm_proxy_budget_name=litellm_proxy_admin_name + ) + + ### START BATCH WRITING DB + CHECKING NEW MODELS### + if prisma_client is not None: + await ProxyStartupEvent.initialize_scheduled_background_jobs( + general_settings=general_settings, + prisma_client=prisma_client, + proxy_budget_rescheduler_min_time=proxy_budget_rescheduler_min_time, + proxy_budget_rescheduler_max_time=proxy_budget_rescheduler_max_time, + proxy_batch_write_at=proxy_batch_write_at, + proxy_logging_obj=proxy_logging_obj, + store_model_in_db=store_model_in_db, + ) + #### API ENDPOINTS #### @router.get( diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index d41cf2dfb..6089bf0c3 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -349,6 +349,31 @@ class ProxyLogging: ) self.premium_user = premium_user + def startup_event( + self, + llm_router: Optional[litellm.Router], + redis_usage_cache: Optional[RedisCache], + ): + """Initialize logging and alerting on proxy startup""" + ## UPDATE SLACK ALERTING ## + self.slack_alerting_instance.update_values(llm_router=llm_router) + + ## UPDATE INTERNAL USAGE CACHE ## + self.update_values( + redis_cache=redis_usage_cache + ) # used by parallel request limiter for rate limiting keys across instances + + self._init_litellm_callbacks( + llm_router=llm_router + ) # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made + + if "daily_reports" in self.slack_alerting_instance.alert_types: + asyncio.create_task( + self.slack_alerting_instance._run_scheduled_daily_report( + llm_router=llm_router + ) + ) # RUN DAILY REPORT (if scheduled) + def update_values( self, alerting: Optional[List] = None, diff --git a/tests/local_testing/test_proxy_server.py b/tests/local_testing/test_proxy_server.py index d76894ce6..e92d84c55 100644 --- a/tests/local_testing/test_proxy_server.py +++ b/tests/local_testing/test_proxy_server.py @@ -1894,3 +1894,27 @@ async def test_proxy_model_group_info_rerank(prisma_client): # asyncio.run(test()) # except Exception as e: # pytest.fail(f"An exception occurred - {str(e)}") + + +@pytest.mark.asyncio +async def test_proxy_server_prisma_setup(): + from litellm.proxy.proxy_server import ProxyStartupEvent + from litellm.proxy.utils import ProxyLogging + from litellm.caching import DualCache + + user_api_key_cache = DualCache() + + with patch.object( + litellm.proxy.proxy_server, "PrismaClient", new=MagicMock() + ) as mock_prisma_client: + mock_client = mock_prisma_client.return_value # This is the mocked instance + mock_client.check_view_exists = AsyncMock() # Mock the check_view_exists method + + ProxyStartupEvent._setup_prisma_client( + database_url=os.getenv("DATABASE_URL"), + proxy_logging_obj=ProxyLogging(user_api_key_cache=user_api_key_cache), + user_api_key_cache=user_api_key_cache, + ) + + await asyncio.sleep(1) + 
mock_client.check_view_exists.assert_called_once()  # the setup helper should schedule the view check exactly once
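
A note for reviewers on the fire-and-forget shape that _setup_prisma_client and its new test rely on: the helper schedules check_view_exists() with asyncio.create_task and returns without awaiting it, which is why the test patches PrismaClient and gives the mocked instance an AsyncMock. Below is a minimal, self-contained sketch of the same pattern under illustrative names (FakeClient and setup_client are stand-ins, not LiteLLM APIs):

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch


class FakeClient:
    """Stand-in for PrismaClient; illustrative only."""

    def __init__(self, database_url: str):
        self.database_url = database_url

    async def check_view_exists(self) -> None:
        ...


def setup_client(database_url):
    """Mirror of the _setup_prisma_client shape: build the client, then
    schedule a non-blocking post-init check and return immediately."""
    client = None
    if database_url is not None:
        client = FakeClient(database_url=database_url)
        # fire-and-forget: callers are not blocked on the view check
        asyncio.create_task(client.check_view_exists())
    return client


async def main():
    with patch(f"{__name__}.FakeClient", new=MagicMock()) as mock_cls:
        # make the mocked method awaitable, as the real test does
        mock_cls.return_value.check_view_exists = AsyncMock()
        setup_client(database_url="postgres://example")
        await asyncio.sleep(0)  # yield once so the scheduled task completes
        mock_cls.return_value.check_view_exists.assert_called_once()


asyncio.run(main())

Yielding once with asyncio.sleep(0) is enough here because the mocked coroutine finishes immediately; the real test above sleeps a full second to give the actual task scheduling some slack.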
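
_initialize_jwt_auth also reuses the proxy's "os.environ/..." convention visible in _new_secret_config.yaml: string values with that prefix are resolved through get_secret before LiteLLM_JWTAuth is constructed. A minimal sketch of that resolution loop, with a plain environment lookup standing in for get_secret (which in the real proxy can also hit external secret managers):

import os

PREFIX = "os.environ/"


def resolve_env_refs(settings: dict) -> dict:
    """Replace 'os.environ/NAME' string values with the env var's value;
    non-string and unprefixed values pass through untouched."""
    resolved = {}
    for key, value in settings.items():
        if isinstance(value, str) and value.startswith(PREFIX):
            env_name = value[len(PREFIX):]
            resolved[key] = os.environ.get(env_name)
        else:
            resolved[key] = value
    return resolved


os.environ["REDIS_HOST"] = "localhost"
print(resolve_env_refs({"redis_host": "os.environ/REDIS_HOST", "redis_port": 6379}))
# {'redis_host': 'localhost', 'redis_port': 6379}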
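
Finally, initialize_scheduled_background_jobs drives its periodic work through APScheduler's AsyncIOScheduler, drawing the budget-reset interval from random.randint(proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time) so multiple proxy replicas don't reset budgets in lockstep. A runnable sketch of that jittered-interval pattern (the 2-4 s bounds and the job body are illustrative; the proxy uses its configured rescheduler window and real DB writes):

import asyncio
import random

from apscheduler.schedulers.asyncio import AsyncIOScheduler


async def reset_budget_job():
    print("resetting budgets...")  # stand-in for the real DB work


async def main():
    scheduler = AsyncIOScheduler()
    # jitter the interval per instance so replicas don't fire in lockstep
    interval = random.randint(2, 4)
    scheduler.add_job(reset_budget_job, "interval", seconds=interval)
    scheduler.start()
    await asyncio.sleep(interval + 1)  # keep the loop alive long enough to fire


asyncio.run(main())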