From 4e310051c7404e5a2f5c41e0643cb9d03ffcc016 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Thu, 24 Oct 2024 22:02:15 -0700 Subject: [PATCH] =?UTF-8?q?feat(proxy=5Fserver.py):=20check=20if=20views?= =?UTF-8?q?=20exist=20on=20proxy=20server=20startup=20+=E2=80=A6=20(#6360)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(proxy_server.py): check if views exist on proxy server startup + refactor startup event logic to <50 LOC * refactor(redis_cache.py): use a default cache value when writing to r… (#6358) * refactor(redis_cache.py): use a default cache value when writing to redis prevent redis from blowing up in high traffic * refactor(redis_cache.py): refactor all cache writes to use self.get_ttl ensures default ttl always used when writing to redis Prevents redis db from blowing up in prod * feat(proxy_cli.py): add new 'log_config' cli param (#6352) * feat(proxy_cli.py): add new 'log_config' cli param Allows passing logging.conf to uvicorn on startup * docs(cli.md): add logging conf to uvicorn cli docs * fix(get_llm_provider_logic.py): fix default api base for litellm_proxy Fixes https://github.com/BerriAI/litellm/issues/6332 * feat(openai_like/embedding): Add support for jina ai embeddings Closes https://github.com/BerriAI/litellm/issues/6337 * docs(deploy.md): update entrypoint.sh filepath post-refactor Fixes outdated docs * feat(prometheus.py): emit time_to_first_token metric on prometheus Closes https://github.com/BerriAI/litellm/issues/6334 * fix(prometheus.py): only emit time to first token metric if stream is True enables more accurate ttft usage * test: handle vertex api instability * fix(get_llm_provider_logic.py): fix import * fix(openai.py): fix deepinfra default api base * fix(anthropic/transformation.py): remove anthropic beta header (#6361) * docs(sidebars.js): add jina ai embedding to docs * docs(sidebars.js): add jina ai to left nav * bump: version 1.50.1 → 1.50.2 * langfuse use helper for get_langfuse_logging_config * Refactor: apply early return (#6369) * (refactor) remove berrispendLogger - unused logging integration (#6363) * fix remove berrispendLogger * remove unused clickhouse logger * fix docs configs.md * (fix) standard logging metadata + add unit testing (#6366) * fix setting StandardLoggingMetadata * add unit testing for standard logging metadata * fix otel logging test * fix linting * fix typing * Revert "(fix) standard logging metadata + add unit testing (#6366)" (#6381) This reverts commit 8359cb6fa9bf7b0bf4f3df630cf8666adffa2813. * add new 35 mode lcard (#6378) * Add claude 3 5 sonnet 20241022 models for all provides (#6380) * Add Claude 3.5 v2 on Amazon Bedrock and Vertex AI. * added anthropic/claude-3-5-sonnet-20241022 * add new 35 mode lcard --------- Co-authored-by: Paul Gauthier Co-authored-by: lowjiansheng <15527690+lowjiansheng@users.noreply.github.com> * test(skip-flaky-google-context-caching-test): google is not reliable. their sample code is also not working * test(test_alangfuse.py): handle flaky langfuse test better * (feat) Arize - Allow using Arize HTTP endpoint (#6364) * arize use helper for get_arize_opentelemetry_config * use helper to get Arize OTEL config * arize add helpers for arize * docs allow using arize http endpoint * fix importing OTEL for Arize * use static methods for ArizeLogger * fix ArizeLogger tests * Litellm dev 10 22 2024 (#6384) * fix(utils.py): add 'disallowed_special' for token counting on .encode() Fixes error when '< endoftext >' in string * Revert "(fix) standard logging metadata + add unit testing (#6366)" (#6381) This reverts commit 8359cb6fa9bf7b0bf4f3df630cf8666adffa2813. * add new 35 mode lcard (#6378) * Add claude 3 5 sonnet 20241022 models for all provides (#6380) * Add Claude 3.5 v2 on Amazon Bedrock and Vertex AI. * added anthropic/claude-3-5-sonnet-20241022 * add new 35 mode lcard --------- Co-authored-by: Paul Gauthier Co-authored-by: lowjiansheng <15527690+lowjiansheng@users.noreply.github.com> * test(skip-flaky-google-context-caching-test): google is not reliable. their sample code is also not working * Fix metadata being overwritten in speech() (#6295) * fix: adding missing redis cluster kwargs (#6318) Co-authored-by: Ali Arian * Add support for `max_completion_tokens` in Azure OpenAI (#6376) Now that Azure supports `max_completion_tokens`, no need for special handling for this param and let it pass thru. More details: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models?tabs=python-secure#api-support * build(model_prices_and_context_window.json): add voyage-finance-2 pricing Closes https://github.com/BerriAI/litellm/issues/6371 * build(model_prices_and_context_window.json): fix llama3.1 pricing model name on map Closes https://github.com/BerriAI/litellm/issues/6310 * feat(realtime_streaming.py): just log specific events Closes https://github.com/BerriAI/litellm/issues/6267 * fix(utils.py): more robust checking if unmapped vertex anthropic model belongs to that family of models Fixes https://github.com/BerriAI/litellm/issues/6383 * Fix Ollama stream handling for tool calls with None content (#6155) * test(test_max_completions): update test now that azure supports 'max_completion_tokens' * fix(handler.py): fix linting error --------- Co-authored-by: Ishaan Jaff Co-authored-by: Low Jian Sheng <15527690+lowjiansheng@users.noreply.github.com> Co-authored-by: David Manouchehri Co-authored-by: Paul Gauthier Co-authored-by: John HU Co-authored-by: Ali Arian <113945203+ali-arian@users.noreply.github.com> Co-authored-by: Ali Arian Co-authored-by: Anand Taralika <46954145+taralika@users.noreply.github.com> Co-authored-by: Nolan Tremelling <34580718+NolanTrem@users.noreply.github.com> * bump: version 1.50.2 → 1.50.3 * build(deps): bump http-proxy-middleware in /docs/my-website (#6395) Bumps [http-proxy-middleware](https://github.com/chimurai/http-proxy-middleware) from 2.0.6 to 2.0.7. - [Release notes](https://github.com/chimurai/http-proxy-middleware/releases) - [Changelog](https://github.com/chimurai/http-proxy-middleware/blob/v2.0.7/CHANGELOG.md) - [Commits](https://github.com/chimurai/http-proxy-middleware/compare/v2.0.6...v2.0.7) --- updated-dependencies: - dependency-name: http-proxy-middleware dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * (docs + testing) Correctly document the timeout value used by litellm proxy is 6000 seconds + add to best practices for prod (#6339) * fix docs use documented timeout * document request timeout * add test for litellm.request_timeout * add test for checking value of timeout * (refactor) move convert dict to model response to llm_response_utils/ (#6393) * refactor move convert dict to model response * fix imports * fix import _handle_invalid_parallel_tool_calls * (refactor) litellm.Router client initialization utils (#6394) * refactor InitalizeOpenAISDKClient * use helper func for _should_create_openai_sdk_client_for_model * use static methods for set client on litellm router * reduce LOC in _get_client_initialization_params * fix _should_create_openai_sdk_client_for_model * code quality fix * test test_should_create_openai_sdk_client_for_model * test test_get_client_initialization_params_openai * fix mypy linting errors * fix OpenAISDKClientInitializationParams * test_get_client_initialization_params_all_env_vars * test_get_client_initialization_params_azure_ai_studio_mistral * test_get_client_initialization_params_default_values * fix _get_client_initialization_params * (fix) Langfuse key based logging (#6372) * langfuse use helper for get_langfuse_logging_config * fix get_langfuse_logger_for_request * fix import * fix get_langfuse_logger_for_request * test_get_langfuse_logger_for_request_with_dynamic_params * unit testing for test_get_langfuse_logger_for_request_with_no_dynamic_params * parameterized langfuse testing * fix langfuse test * fix langfuse logging * fix test_aaalangfuse_logging_metadata * fix langfuse log metadata test * fix langfuse logger * use create_langfuse_logger_from_credentials * fix test_get_langfuse_logger_for_request_with_no_dynamic_params * fix correct langfuse/ folder structure * use static methods for langfuse logger * add commment on langfuse handler * fix linting error * add unit testing for langfuse logging * fix linting * fix failure handler langfuse * Revert "(refactor) litellm.Router client initialization utils (#6394)" (#6403) This reverts commit b70147f63b5ad95d90be371a50c6248fe21b20e8. * def test_text_completion_with_echo(stream): (#6401) test * fix linting - remove # noqa PLR0915 from fixed function * test: cleanup codestral tests - backend api unavailable * (refactor) prometheus async_log_success_event to be under 100 LOC (#6416) * unit testig for prometheus * unit testing for success metrics * use 1 helper for _increment_token_metrics * use helper for _increment_remaining_budget_metrics * use _increment_remaining_budget_metrics * use _increment_top_level_request_and_spend_metrics * use helper for _set_latency_metrics * remove noqa violation * fix test prometheus * test prometheus * unit testing for all prometheus helper functions * fix prom unit tests * fix unit tests prometheus * fix unit test prom * (refactor) router - use static methods for client init utils (#6420) * use InitalizeOpenAISDKClient * use InitalizeOpenAISDKClient static method * fix # noqa: PLR0915 * (code cleanup) remove unused and undocumented logging integrations - litedebugger, berrispend (#6406) * code cleanup remove unused and undocumented code files * fix unused logging integrations cleanup * bump: version 1.50.3 → 1.50.4 --------- Signed-off-by: dependabot[bot] Co-authored-by: Ishaan Jaff Co-authored-by: Hakan Taşköprü Co-authored-by: Low Jian Sheng <15527690+lowjiansheng@users.noreply.github.com> Co-authored-by: David Manouchehri Co-authored-by: Paul Gauthier Co-authored-by: John HU Co-authored-by: Ali Arian <113945203+ali-arian@users.noreply.github.com> Co-authored-by: Ali Arian Co-authored-by: Anand Taralika <46954145+taralika@users.noreply.github.com> Co-authored-by: Nolan Tremelling <34580718+NolanTrem@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- litellm/integrations/prometheus.py | 2 + .../convert_dict_to_response.py | 3 + litellm/proxy/_new_secret_config.yaml | 1 - litellm/proxy/proxy_server.py | 358 +++++++++++------- litellm/proxy/utils.py | 25 ++ tests/local_testing/test_proxy_server.py | 24 ++ 6 files changed, 271 insertions(+), 142 deletions(-) diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index bf19c364e..54563ab61 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -346,12 +346,14 @@ class PrometheusLogger(CustomLogger): standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( "standard_logging_object" ) + if standard_logging_payload is None or not isinstance( standard_logging_payload, dict ): raise ValueError( f"standard_logging_object is required, got={standard_logging_payload}" ) + model = kwargs.get("model", "") litellm_params = kwargs.get("litellm_params", {}) or {} _metadata = litellm_params.get("metadata", {}) diff --git a/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py b/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py index 95749037d..76077ad46 100644 --- a/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py +++ b/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py @@ -260,6 +260,7 @@ def convert_to_model_response_object( # noqa: PLR0915 ] = None, # used for supporting 'json_schema' on older models ): received_args = locals() + additional_headers = get_response_headers(_response_headers) if hidden_params is None: @@ -448,11 +449,13 @@ def convert_to_model_response_object( # noqa: PLR0915 ): if response_object is None: raise Exception("Error in response object format") + return LiteLLMResponseObjectHandler.convert_to_image_response( response_object=response_object, model_response_object=model_response_object, hidden_params=hidden_params, ) + elif response_type == "audio_transcription" and ( model_response_object is None or isinstance(model_response_object, TranscriptionResponse) diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 2cdf35b70..ee56b3366 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -48,4 +48,3 @@ router_settings: redis_host: os.environ/REDIS_HOST redis_port: os.environ/REDIS_PORT redis_password: os.environ/REDIS_PASSWORD - diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 163d2ff25..bcdd4e86c 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -638,18 +638,6 @@ def _resolve_pydantic_type(typ) -> List: return typs -def prisma_setup(database_url: Optional[str]): - global prisma_client, proxy_logging_obj, user_api_key_cache - - if database_url is not None: - try: - prisma_client = PrismaClient( - database_url=database_url, proxy_logging_obj=proxy_logging_obj - ) - except Exception as e: - raise e - - def load_from_azure_key_vault(use_azure_key_vault: bool = False): if use_azure_key_vault is False: return @@ -1543,7 +1531,6 @@ class ProxyConfig: ## INIT PROXY REDIS USAGE CLIENT ## redis_usage_cache = litellm.cache.cache - async def get_config(self, config_file_path: Optional[str] = None) -> dict: """ Load config file @@ -2796,137 +2783,55 @@ def giveup(e): return result -@router.on_event("startup") -async def startup_event(): # noqa: PLR0915 - global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, premium_user, _license_check - import json +class ProxyStartupEvent: + @classmethod + def _initialize_startup_logging( + cls, + llm_router: Optional[litellm.Router], + proxy_logging_obj: ProxyLogging, + redis_usage_cache: Optional[RedisCache], + ): + """Initialize logging and alerting on startup""" + ## COST TRACKING ## + cost_tracking() - init_verbose_loggers() + ## Error Tracking ## + error_tracking() - ### LOAD MASTER KEY ### - # check if master key set in environment - load from there - master_key = get_secret("LITELLM_MASTER_KEY", None) # type: ignore - # check if DATABASE_URL in environment - load from there - if prisma_client is None: - _db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore - prisma_setup(database_url=_db_url) + proxy_logging_obj.startup_event( + llm_router=llm_router, redis_usage_cache=redis_usage_cache + ) - ### LOAD CONFIG ### - worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore - env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH") - verbose_proxy_logger.debug("worker_config: %s", worker_config) - # check if it's a valid file path - if env_config_yaml is not None: - if os.path.isfile(env_config_yaml) and proxy_config.is_yaml( - config_file_path=env_config_yaml - ): - ( - llm_router, - llm_model_list, - general_settings, - ) = await proxy_config.load_config( - router=llm_router, config_file_path=env_config_yaml - ) - elif worker_config is not None: - if ( - isinstance(worker_config, str) - and os.path.isfile(worker_config) - and proxy_config.is_yaml(config_file_path=worker_config) - ): - ( - llm_router, - llm_model_list, - general_settings, - ) = await proxy_config.load_config( - router=llm_router, config_file_path=worker_config - ) - elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance( - worker_config, str - ): - ( - llm_router, - llm_model_list, - general_settings, - ) = await proxy_config.load_config( - router=llm_router, config_file_path=worker_config - ) - elif isinstance(worker_config, dict): - await initialize(**worker_config) + @classmethod + def _initialize_jwt_auth( + cls, + general_settings: dict, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + ): + """Initialize JWT auth on startup""" + if general_settings.get("litellm_jwtauth", None) is not None: + for k, v in general_settings["litellm_jwtauth"].items(): + if isinstance(v, str) and v.startswith("os.environ/"): + general_settings["litellm_jwtauth"][k] = get_secret(v) + litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"]) else: - # if not, assume it's a json string - worker_config = json.loads(worker_config) - if isinstance(worker_config, dict): - await initialize(**worker_config) - - ## CHECK PREMIUM USER - verbose_proxy_logger.debug( - "litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format( - premium_user + litellm_jwtauth = LiteLLM_JWTAuth() + jwt_handler.update_environment( + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + litellm_jwtauth=litellm_jwtauth, ) - ) - if premium_user is False: - premium_user = _license_check.is_premium() - verbose_proxy_logger.debug( - "litellm.proxy.proxy_server.py::startup() - PREMIUM USER value - {}".format( - premium_user - ) - ) - - ## COST TRACKING ## - cost_tracking() - - ## Error Tracking ## - error_tracking() - - ## UPDATE SLACK ALERTING ## - proxy_logging_obj.slack_alerting_instance.update_values(llm_router=llm_router) - - db_writer_client = HTTPHandler() - - ## UPDATE INTERNAL USAGE CACHE ## - proxy_logging_obj.update_values( - redis_cache=redis_usage_cache - ) # used by parallel request limiter for rate limiting keys across instances - - proxy_logging_obj._init_litellm_callbacks( - llm_router=llm_router - ) # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made - - if "daily_reports" in proxy_logging_obj.slack_alerting_instance.alert_types: - asyncio.create_task( - proxy_logging_obj.slack_alerting_instance._run_scheduled_daily_report( - llm_router=llm_router - ) - ) # RUN DAILY REPORT (if scheduled) - - ## JWT AUTH ## - if general_settings.get("litellm_jwtauth", None) is not None: - for k, v in general_settings["litellm_jwtauth"].items(): - if isinstance(v, str) and v.startswith("os.environ/"): - general_settings["litellm_jwtauth"][k] = get_secret(v) - litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"]) - else: - litellm_jwtauth = LiteLLM_JWTAuth() - jwt_handler.update_environment( - prisma_client=prisma_client, - user_api_key_cache=user_api_key_cache, - litellm_jwtauth=litellm_jwtauth, - ) - - if use_background_health_checks: - asyncio.create_task( - _run_background_health_check() - ) # start the background health check coroutine. - - if prompt_injection_detection_obj is not None: - prompt_injection_detection_obj.update_environment(router=llm_router) - - verbose_proxy_logger.debug("prisma_client: %s", prisma_client) - if prisma_client is not None: - await prisma_client.connect() - - if prisma_client is not None and master_key is not None: + @classmethod + def _add_master_key_hash_to_db( + cls, + master_key: str, + prisma_client: PrismaClient, + litellm_proxy_admin_name: str, + general_settings: dict, + ): + """Adds master key hash to db for cost tracking""" if os.getenv("PROXY_ADMIN_ID", None) is not None: litellm_proxy_admin_name = os.getenv( "PROXY_ADMIN_ID", litellm_proxy_admin_name @@ -2951,7 +2856,9 @@ async def startup_event(): # noqa: PLR0915 ) asyncio.create_task(task_1) - if prisma_client is not None and litellm.max_budget > 0: + @classmethod + def _add_proxy_budget_to_db(cls, litellm_proxy_budget_name: str): + """Adds a global proxy budget to db""" if litellm.budget_duration is None: raise Exception( "budget_duration not set on Proxy. budget_duration is required to use max_budget." @@ -2977,8 +2884,18 @@ async def startup_event(): # noqa: PLR0915 ) ) - ### START BATCH WRITING DB + CHECKING NEW MODELS### - if prisma_client is not None: + @classmethod + async def initialize_scheduled_background_jobs( + cls, + general_settings: dict, + prisma_client: PrismaClient, + proxy_budget_rescheduler_min_time: int, + proxy_budget_rescheduler_max_time: int, + proxy_batch_write_at: int, + proxy_logging_obj: ProxyLogging, + store_model_in_db: bool, + ): + """Initializes scheduled background jobs""" scheduler = AsyncIOScheduler() interval = random.randint( proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time @@ -3067,6 +2984,165 @@ async def startup_event(): # noqa: PLR0915 scheduler.start() + @classmethod + def _setup_prisma_client( + cls, + database_url: Optional[str], + proxy_logging_obj: ProxyLogging, + user_api_key_cache: DualCache, + ) -> Optional[PrismaClient]: + """ + - Sets up prisma client + - Adds necessary views to proxy + """ + prisma_client: Optional[PrismaClient] = None + if database_url is not None: + try: + prisma_client = PrismaClient( + database_url=database_url, proxy_logging_obj=proxy_logging_obj + ) + except Exception as e: + raise e + + ## Add necessary views to proxy ## + asyncio.create_task( + prisma_client.check_view_exists() + ) # check if all necessary views exist. Don't block execution + + return prisma_client + + +@router.on_event("startup") +async def startup_event(): + global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, premium_user, _license_check + import json + + init_verbose_loggers() + + ### LOAD MASTER KEY ### + # check if master key set in environment - load from there + master_key = get_secret("LITELLM_MASTER_KEY", None) # type: ignore + # check if DATABASE_URL in environment - load from there + if prisma_client is None: + _db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore + prisma_client = ProxyStartupEvent._setup_prisma_client( + database_url=_db_url, + proxy_logging_obj=proxy_logging_obj, + user_api_key_cache=user_api_key_cache, + ) + + ### LOAD CONFIG ### + worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore + env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH") + verbose_proxy_logger.debug("worker_config: %s", worker_config) + # check if it's a valid file path + if env_config_yaml is not None: + if os.path.isfile(env_config_yaml) and proxy_config.is_yaml( + config_file_path=env_config_yaml + ): + ( + llm_router, + llm_model_list, + general_settings, + ) = await proxy_config.load_config( + router=llm_router, config_file_path=env_config_yaml + ) + elif worker_config is not None: + if ( + isinstance(worker_config, str) + and os.path.isfile(worker_config) + and proxy_config.is_yaml(config_file_path=worker_config) + ): + ( + llm_router, + llm_model_list, + general_settings, + ) = await proxy_config.load_config( + router=llm_router, config_file_path=worker_config + ) + elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance( + worker_config, str + ): + ( + llm_router, + llm_model_list, + general_settings, + ) = await proxy_config.load_config( + router=llm_router, config_file_path=worker_config + ) + elif isinstance(worker_config, dict): + await initialize(**worker_config) + else: + # if not, assume it's a json string + worker_config = json.loads(worker_config) + if isinstance(worker_config, dict): + await initialize(**worker_config) + + ## CHECK PREMIUM USER + verbose_proxy_logger.debug( + "litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format( + premium_user + ) + ) + if premium_user is False: + premium_user = _license_check.is_premium() + + verbose_proxy_logger.debug( + "litellm.proxy.proxy_server.py::startup() - PREMIUM USER value - {}".format( + premium_user + ) + ) + + ProxyStartupEvent._initialize_startup_logging( + llm_router=llm_router, + proxy_logging_obj=proxy_logging_obj, + redis_usage_cache=redis_usage_cache, + ) + + ## JWT AUTH ## + ProxyStartupEvent._initialize_jwt_auth( + general_settings=general_settings, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + ) + + if use_background_health_checks: + asyncio.create_task( + _run_background_health_check() + ) # start the background health check coroutine. + + if prompt_injection_detection_obj is not None: # [TODO] - REFACTOR THIS + prompt_injection_detection_obj.update_environment(router=llm_router) + + verbose_proxy_logger.debug("prisma_client: %s", prisma_client) + if prisma_client is not None: + await prisma_client.connect() + + if prisma_client is not None and master_key is not None: + ProxyStartupEvent._add_master_key_hash_to_db( + master_key=master_key, + prisma_client=prisma_client, + litellm_proxy_admin_name=litellm_proxy_admin_name, + general_settings=general_settings, + ) + + if prisma_client is not None and litellm.max_budget > 0: + ProxyStartupEvent._add_proxy_budget_to_db( + litellm_proxy_budget_name=litellm_proxy_admin_name + ) + + ### START BATCH WRITING DB + CHECKING NEW MODELS### + if prisma_client is not None: + await ProxyStartupEvent.initialize_scheduled_background_jobs( + general_settings=general_settings, + prisma_client=prisma_client, + proxy_budget_rescheduler_min_time=proxy_budget_rescheduler_min_time, + proxy_budget_rescheduler_max_time=proxy_budget_rescheduler_max_time, + proxy_batch_write_at=proxy_batch_write_at, + proxy_logging_obj=proxy_logging_obj, + store_model_in_db=store_model_in_db, + ) + #### API ENDPOINTS #### @router.get( diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index d41cf2dfb..6089bf0c3 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -349,6 +349,31 @@ class ProxyLogging: ) self.premium_user = premium_user + def startup_event( + self, + llm_router: Optional[litellm.Router], + redis_usage_cache: Optional[RedisCache], + ): + """Initialize logging and alerting on proxy startup""" + ## UPDATE SLACK ALERTING ## + self.slack_alerting_instance.update_values(llm_router=llm_router) + + ## UPDATE INTERNAL USAGE CACHE ## + self.update_values( + redis_cache=redis_usage_cache + ) # used by parallel request limiter for rate limiting keys across instances + + self._init_litellm_callbacks( + llm_router=llm_router + ) # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made + + if "daily_reports" in self.slack_alerting_instance.alert_types: + asyncio.create_task( + self.slack_alerting_instance._run_scheduled_daily_report( + llm_router=llm_router + ) + ) # RUN DAILY REPORT (if scheduled) + def update_values( self, alerting: Optional[List] = None, diff --git a/tests/local_testing/test_proxy_server.py b/tests/local_testing/test_proxy_server.py index d76894ce6..e92d84c55 100644 --- a/tests/local_testing/test_proxy_server.py +++ b/tests/local_testing/test_proxy_server.py @@ -1894,3 +1894,27 @@ async def test_proxy_model_group_info_rerank(prisma_client): # asyncio.run(test()) # except Exception as e: # pytest.fail(f"An exception occurred - {str(e)}") + + +@pytest.mark.asyncio +async def test_proxy_server_prisma_setup(): + from litellm.proxy.proxy_server import ProxyStartupEvent + from litellm.proxy.utils import ProxyLogging + from litellm.caching import DualCache + + user_api_key_cache = DualCache() + + with patch.object( + litellm.proxy.proxy_server, "PrismaClient", new=MagicMock() + ) as mock_prisma_client: + mock_client = mock_prisma_client.return_value # This is the mocked instance + mock_client.check_view_exists = AsyncMock() # Mock the check_view_exists method + + ProxyStartupEvent._setup_prisma_client( + database_url=os.getenv("DATABASE_URL"), + proxy_logging_obj=ProxyLogging(user_api_key_cache=user_api_key_cache), + user_api_key_cache=user_api_key_cache, + ) + + await asyncio.sleep(1) + mock_client.check_view_exists.assert_called_once()