diff --git a/docs/my-website/docs/proxy/self_serve.md b/docs/my-website/docs/proxy/self_serve.md
index 3de662da41..2fc17d952e 100644
--- a/docs/my-website/docs/proxy/self_serve.md
+++ b/docs/my-website/docs/proxy/self_serve.md
@@ -296,8 +296,6 @@ When you connect litellm to your SSO provider, litellm can auto-create teams. Us
 
 **Usage**
 
-1. Set the default params for new teams
-
 ```yaml showLineNumbers title="Default Params for new teams"
 litellm_settings:
   default_team_params: # Default Params to apply when litellm auto creates a team from SSO IDP provider
diff --git a/litellm/constants.py b/litellm/constants.py
index c8248f548a..12bfd17815 100644
--- a/litellm/constants.py
+++ b/litellm/constants.py
@@ -480,6 +480,7 @@ RESPONSE_FORMAT_TOOL_NAME = "json_tool_call"  # default tool name used when conv
 
 ########################### Logging Callback Constants ###########################
 AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
+PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES = 5
 MCP_TOOL_NAME_PREFIX = "mcp_tool"
 
 ########################### LiteLLM Proxy Specific Constants ###########################
@@ -514,6 +515,7 @@ LITELLM_PROXY_ADMIN_NAME = "default_user_id"
 
 ########################### DB CRON JOB NAMES ###########################
 DB_SPEND_UPDATE_JOB_NAME = "db_spend_update_job"
+PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME = "prometheus_emit_budget_metrics_job"
 DEFAULT_CRON_JOB_LOCK_TTL_SECONDS = 60  # 1 minute
 PROXY_BUDGET_RESCHEDULER_MIN_TIME = 597
 PROXY_BUDGET_RESCHEDULER_MAX_TIME = 605
diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py
index 6fba69d005..f61321e53d 100644
--- a/litellm/integrations/prometheus.py
+++ b/litellm/integrations/prometheus.py
@@ -1,10 +1,19 @@
 # used for /metrics endpoint on LiteLLM Proxy
 #### What this does ####
 # On success, log events to Prometheus
-import asyncio
 import sys
 from datetime import datetime, timedelta
-from typing import Any, Awaitable, Callable, List, Literal, Optional, Tuple, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    cast,
+)
 
 import litellm
 from litellm._logging import print_verbose, verbose_logger
@@ -14,6 +23,11 @@ from litellm.types.integrations.prometheus import *
 from litellm.types.utils import StandardLoggingPayload
 from litellm.utils import get_end_user_id_for_cost_tracking
 
+if TYPE_CHECKING:
+    from apscheduler.schedulers.asyncio import AsyncIOScheduler
+else:
+    AsyncIOScheduler = Any
+
 
 class PrometheusLogger(CustomLogger):
     # Class variables or attributes
@@ -359,8 +373,6 @@ class PrometheusLogger(CustomLogger):
                     label_name="litellm_requests_metric"
                 ),
             )
-            self._initialize_prometheus_startup_metrics()
-
         except Exception as e:
             print_verbose(f"Got exception on init prometheus client {str(e)}")
             raise e
@@ -988,9 +1000,9 @@ class PrometheusLogger(CustomLogger):
     ):
         try:
             verbose_logger.debug("setting remaining tokens requests metric")
-            standard_logging_payload: Optional[
-                StandardLoggingPayload
-            ] = request_kwargs.get("standard_logging_object")
+            standard_logging_payload: Optional[StandardLoggingPayload] = (
+                request_kwargs.get("standard_logging_object")
+            )
 
             if standard_logging_payload is None:
                 return
@@ -1337,24 +1349,6 @@ class PrometheusLogger(CustomLogger):
 
         return max_budget - spend
 
-    def _initialize_prometheus_startup_metrics(self):
-        """
-        Initialize prometheus startup metrics
-
-        Helper to create tasks for initializing metrics that are required on startup - eg. remaining budget metrics
-        """
-        if litellm.prometheus_initialize_budget_metrics is not True:
-            verbose_logger.debug("Prometheus: skipping budget metrics initialization")
-            return
-
-        try:
-            if asyncio.get_running_loop():
-                asyncio.create_task(self._initialize_remaining_budget_metrics())
-        except RuntimeError as e:  # no running event loop
-            verbose_logger.exception(
-                f"No running event loop - skipping budget metrics initialization: {str(e)}"
-            )
-
     async def _initialize_budget_metrics(
         self,
         data_fetch_function: Callable[..., Awaitable[Tuple[List[Any], Optional[int]]]],
@@ -1475,12 +1469,41 @@ class PrometheusLogger(CustomLogger):
             data_type="keys",
         )
 
-    async def _initialize_remaining_budget_metrics(self):
+    async def initialize_remaining_budget_metrics(self):
         """
-        Initialize remaining budget metrics for all teams to avoid metric discrepancies.
+        Handler for initializing remaining budget metrics for all teams to avoid metric discrepancies.
 
         Runs when prometheus logger starts up.
+
+        - If redis cache is available, we use the pod lock manager to acquire a lock and initialize the metrics.
+            - Ensures only one pod emits the metrics at a time.
+        - If redis cache is not available, we initialize the metrics directly.
         """
+        from litellm.constants import PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME
+        from litellm.proxy.proxy_server import proxy_logging_obj
+
+        pod_lock_manager = proxy_logging_obj.db_spend_update_writer.pod_lock_manager
+
+        # if using redis, ensure only one pod emits the metrics at a time
+        if pod_lock_manager and pod_lock_manager.redis_cache:
+            if await pod_lock_manager.acquire_lock(
+                cronjob_id=PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME
+            ):
+                try:
+                    await self._initialize_remaining_budget_metrics()
+                finally:
+                    await pod_lock_manager.release_lock(
+                        cronjob_id=PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME
+                    )
+        else:
+            # if not using redis, initialize the metrics directly
+            await self._initialize_remaining_budget_metrics()
+
+    async def _initialize_remaining_budget_metrics(self):
+        """
+        Helper to initialize remaining budget metrics for all teams and API keys.
+        """
+        verbose_logger.debug("Emitting key, team budget metrics....")
         await self._initialize_team_budget_metrics()
         await self._initialize_api_key_budget_metrics()
 
@@ -1737,6 +1760,36 @@ class PrometheusLogger(CustomLogger):
             return (end_time - start_time).total_seconds()
         return None
 
+    @staticmethod
+    def initialize_budget_metrics_cron_job(scheduler: AsyncIOScheduler):
+        """
+        Initialize budget metrics as a cron job. This job runs every `PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES` minutes.
+
+        It emits the current remaining budget metrics for all Keys and Teams.
+        """
+        from litellm.constants import PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
+        from litellm.integrations.custom_logger import CustomLogger
+        from litellm.integrations.prometheus import PrometheusLogger
+
+        prometheus_loggers: List[CustomLogger] = (
+            litellm.logging_callback_manager.get_custom_loggers_for_type(
+                callback_type=PrometheusLogger
+            )
+        )
+        # we need to get the initialized prometheus logger instance(s) and call logger.initialize_remaining_budget_metrics() on them
+        verbose_logger.debug("found %s prometheus loggers", len(prometheus_loggers))
+        if len(prometheus_loggers) > 0:
+            prometheus_logger = cast(PrometheusLogger, prometheus_loggers[0])
+            verbose_logger.debug(
+                "Initializing remaining budget metrics as a cron job executing every %s minutes"
+                % PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
+            )
+            scheduler.add_job(
+                prometheus_logger.initialize_remaining_budget_metrics,
+                "interval",
+                minutes=PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES,
+            )
+
     @staticmethod
     def _mount_metrics_endpoint(premium_user: bool):
         """
diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py
index b32dc5c691..12ae51822c 100644
--- a/litellm/proxy/db/db_spend_update_writer.py
+++ b/litellm/proxy/db/db_spend_update_writer.py
@@ -53,7 +53,7 @@ class DBSpendUpdateWriter:
     ):
         self.redis_cache = redis_cache
         self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache)
-        self.pod_lock_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME)
+        self.pod_lock_manager = PodLockManager()
         self.spend_update_queue = SpendUpdateQueue()
         self.daily_spend_update_queue = DailySpendUpdateQueue()
 
@@ -383,7 +383,9 @@ class DBSpendUpdateWriter:
             )
 
         # Only commit from redis to db if this pod is the leader
-        if await self.pod_lock_manager.acquire_lock():
+        if await self.pod_lock_manager.acquire_lock(
+            cronjob_id=DB_SPEND_UPDATE_JOB_NAME,
+        ):
             verbose_proxy_logger.debug("acquired lock for spend updates")
 
             try:
@@ -411,7 +413,9 @@ class DBSpendUpdateWriter:
         except Exception as e:
             verbose_proxy_logger.error(f"Error committing spend updates: {e}")
         finally:
-            await self.pod_lock_manager.release_lock()
+            await self.pod_lock_manager.release_lock(
+                cronjob_id=DB_SPEND_UPDATE_JOB_NAME,
+            )
 
     async def _commit_spend_updates_to_db_without_redis_buffer(
         self,
diff --git a/litellm/proxy/db/db_transaction_queue/pod_lock_manager.py b/litellm/proxy/db/db_transaction_queue/pod_lock_manager.py
index cb4a43a802..be3be64546 100644
--- a/litellm/proxy/db/db_transaction_queue/pod_lock_manager.py
+++ b/litellm/proxy/db/db_transaction_queue/pod_lock_manager.py
@@ -21,18 +21,18 @@ class PodLockManager:
     Ensures that only one pod can run a cron job at a time.
     """
 
-    def __init__(self, cronjob_id: str, redis_cache: Optional[RedisCache] = None):
+    def __init__(self, redis_cache: Optional[RedisCache] = None):
         self.pod_id = str(uuid.uuid4())
-        self.cronjob_id = cronjob_id
         self.redis_cache = redis_cache
-        # Define a unique key for this cronjob lock in Redis.
-        self.lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
 
     @staticmethod
    def get_redis_lock_key(cronjob_id: str) -> str:
         return f"cronjob_lock:{cronjob_id}"
 
-    async def acquire_lock(self) -> Optional[bool]:
+    async def acquire_lock(
+        self,
+        cronjob_id: str,
+    ) -> Optional[bool]:
         """
         Attempt to acquire the lock for a specific cron job using Redis.
         Uses the SET command with NX and EX options to ensure atomicity.
@@ -44,12 +44,13 @@ class PodLockManager:
             verbose_proxy_logger.debug(
                 "Pod %s attempting to acquire Redis lock for cronjob_id=%s",
                 self.pod_id,
-                self.cronjob_id,
+                cronjob_id,
             )
             # Try to set the lock key with the pod_id as its value, only if it doesn't exist (NX)
             # and with an expiration (EX) to avoid deadlocks.
+            lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
             acquired = await self.redis_cache.async_set_cache(
-                self.lock_key,
+                lock_key,
                 self.pod_id,
                 nx=True,
                 ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
@@ -58,13 +59,13 @@ class PodLockManager:
                 verbose_proxy_logger.info(
                     "Pod %s successfully acquired Redis lock for cronjob_id=%s",
                     self.pod_id,
-                    self.cronjob_id,
+                    cronjob_id,
                 )
 
                 return True
             else:
                 # Check if the current pod already holds the lock
-                current_value = await self.redis_cache.async_get_cache(self.lock_key)
+                current_value = await self.redis_cache.async_get_cache(lock_key)
                 if current_value is not None:
                     if isinstance(current_value, bytes):
                         current_value = current_value.decode("utf-8")
@@ -72,18 +73,21 @@ class PodLockManager:
                         verbose_proxy_logger.info(
                             "Pod %s already holds the Redis lock for cronjob_id=%s",
                             self.pod_id,
-                            self.cronjob_id,
+                            cronjob_id,
                         )
-                        self._emit_acquired_lock_event(self.cronjob_id, self.pod_id)
+                        self._emit_acquired_lock_event(cronjob_id, self.pod_id)
                         return True
             return False
         except Exception as e:
             verbose_proxy_logger.error(
-                f"Error acquiring Redis lock for {self.cronjob_id}: {e}"
+                f"Error acquiring Redis lock for {cronjob_id}: {e}"
             )
             return False
 
-    async def release_lock(self):
+    async def release_lock(
+        self,
+        cronjob_id: str,
+    ):
         """
         Release the lock if the current pod holds it.
         Uses get and delete commands to ensure that only the owner can release the lock.
@@ -92,46 +96,52 @@ class PodLockManager:
             verbose_proxy_logger.debug("redis_cache is None, skipping release_lock")
             return
         try:
+            cronjob_id = cronjob_id
             verbose_proxy_logger.debug(
                 "Pod %s attempting to release Redis lock for cronjob_id=%s",
                 self.pod_id,
-                self.cronjob_id,
+                cronjob_id,
             )
-            current_value = await self.redis_cache.async_get_cache(self.lock_key)
+            lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+
+            current_value = await self.redis_cache.async_get_cache(lock_key)
             if current_value is not None:
                 if isinstance(current_value, bytes):
                     current_value = current_value.decode("utf-8")
                 if current_value == self.pod_id:
-                    result = await self.redis_cache.async_delete_cache(self.lock_key)
+                    result = await self.redis_cache.async_delete_cache(lock_key)
                     if result == 1:
                         verbose_proxy_logger.info(
                             "Pod %s successfully released Redis lock for cronjob_id=%s",
                             self.pod_id,
-                            self.cronjob_id,
+                            cronjob_id,
+                        )
+                        self._emit_released_lock_event(
+                            cronjob_id=cronjob_id,
+                            pod_id=self.pod_id,
                         )
-                        self._emit_released_lock_event(self.cronjob_id, self.pod_id)
                     else:
                         verbose_proxy_logger.debug(
                             "Pod %s failed to release Redis lock for cronjob_id=%s",
                             self.pod_id,
-                            self.cronjob_id,
+                            cronjob_id,
                         )
                 else:
                     verbose_proxy_logger.debug(
                         "Pod %s cannot release Redis lock for cronjob_id=%s because it is held by pod %s",
                         self.pod_id,
-                        self.cronjob_id,
+                        cronjob_id,
                         current_value,
                     )
             else:
                 verbose_proxy_logger.debug(
                     "Pod %s attempted to release Redis lock for cronjob_id=%s, but no lock was found",
                     self.pod_id,
-                    self.cronjob_id,
+                    cronjob_id,
                 )
         except Exception as e:
             verbose_proxy_logger.error(
-                f"Error releasing Redis lock for {self.cronjob_id}: {e}"
+                f"Error releasing Redis lock for {cronjob_id}: {e}"
             )
 
     @staticmethod
diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py
index e6588cb5d2..1804aa9f56 100644
--- a/litellm/proxy/management_endpoints/ui_sso.py
+++ b/litellm/proxy/management_endpoints/ui_sso.py
@@ -469,9 +469,6 @@ async def auth_callback(request: Request):  # noqa: PLR0915
         result=result,
         user_info=user_info,
         user_email=user_email,
-        user_id_models=user_id_models,
-        max_internal_user_budget=max_internal_user_budget,
-        internal_user_budget_duration=internal_user_budget_duration,
         user_defined_values=user_defined_values,
         prisma_client=prisma_client,
     )
@@ -832,37 +829,20 @@ class SSOAuthenticationHandler:
         result: Optional[Union[CustomOpenID, OpenID, dict]],
         user_info: Optional[Union[NewUserResponse, LiteLLM_UserTable]],
         user_email: Optional[str],
-        user_id_models: List[str],
-        max_internal_user_budget: Optional[float],
-        internal_user_budget_duration: Optional[str],
         user_defined_values: Optional[SSOUserDefinedValues],
         prisma_client: PrismaClient,
     ):
         """
         Connects the SSO Users to the User Table in LiteLLM DB
 
-        - If user on LiteLLM DB, update the user_id with the SSO user_id
+        - If user on LiteLLM DB, update the user_email with the SSO user_email
         - If user not on LiteLLM DB, insert the user into LiteLLM DB
         """
         try:
             if user_info is not None:
                 user_id = user_info.user_id
-                user_defined_values = SSOUserDefinedValues(
-                    models=getattr(user_info, "models", user_id_models),
-                    user_id=user_info.user_id or "",
-                    user_email=getattr(user_info, "user_email", user_email),
-                    user_role=getattr(user_info, "user_role", None),
-                    max_budget=getattr(
-                        user_info, "max_budget", max_internal_user_budget
-                    ),
-                    budget_duration=getattr(
-                        user_info, "budget_duration", internal_user_budget_duration
-                    ),
-                )
-
-                # update id
                 await prisma_client.db.litellm_usertable.update_many(
-                    where={"user_email": user_email}, data={"user_id": user_id}  # type: ignore
+                    where={"user_id": user_id}, data={"user_email": user_email}
                 )
             else:
                 verbose_proxy_logger.info(
@@ -1075,7 +1055,7 @@ class MicrosoftSSOHandler:
         response = response or {}
         verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}")
         openid_response = CustomOpenID(
-            email=response.get("mail"),
+            email=response.get("userPrincipalName") or response.get("mail"),
             display_name=response.get("displayName"),
             provider="microsoft",
             id=response.get("id"),
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index aa0a81cde5..847ca7ce56 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -10,7 +10,42 @@ model_list:
       api_key: fake-key
 
 litellm_settings:
-  default_team_params: # Default Params to apply when litellm auto creates a team from SSO IDP provider
-    max_budget: 100 # Optional[float], optional): $100 budget for the team
-    budget_duration: 30d # Optional[str], optional): 30 days budget_duration for the team
-    models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by the team
\ No newline at end of file
+  prometheus_initialize_budget_metrics: true
+  callbacks: ["prometheus"]
+
+mcp_tools:
+  - name: "get_current_time"
+    description: "Get the current time"
+    input_schema: {
+      "type": "object",
+      "properties": {
+        "format": {
+          "type": "string",
+          "description": "The format of the time to return",
+          "enum": ["short"]
+        }
+      }
+    }
+    handler: "mcp_tools.get_current_time"
+  - name: "get_current_date"
+    description: "Get the current date"
+    input_schema: {
+      "type": "object",
+      "properties": {
+        "format": {
+          "type": "string",
+          "description": "The format of the date to return",
+          "enum": ["short"]
+        }
+      }
+    }
+    handler: "mcp_tools.get_current_date"
+  default_team_settings:
+    - team_id: test_dev
+      success_callback: ["langfuse", "s3"]
+      langfuse_secret: secret-test-key
+      langfuse_public_key: public-test-key
+    - team_id: my_workflows
+      success_callback: ["langfuse", "s3"]
+      langfuse_secret: secret-workflows-key
+      langfuse_public_key: public-workflows-key
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index ddfb7118d7..84b515f405 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -803,9 +803,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
     dual_cache=user_api_key_cache
 )
 litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
-redis_usage_cache: Optional[
-    RedisCache
-] = None  # redis cache used for tracking spend, tpm/rpm limits
+redis_usage_cache: Optional[RedisCache] = (
+    None  # redis cache used for tracking spend, tpm/rpm limits
+)
 user_custom_auth = None
 user_custom_key_generate = None
 user_custom_sso = None
@@ -1131,9 +1131,9 @@ async def update_cache(  # noqa: PLR0915
         _id = "team_id:{}".format(team_id)
         try:
             # Fetch the existing cost for the given user
-            existing_spend_obj: Optional[
-                LiteLLM_TeamTable
-            ] = await user_api_key_cache.async_get_cache(key=_id)
+            existing_spend_obj: Optional[LiteLLM_TeamTable] = (
+                await user_api_key_cache.async_get_cache(key=_id)
+            )
             if existing_spend_obj is None:
                 # do nothing if team not in api key cache
                 return
@@ -2812,9 +2812,9 @@ async def initialize(  # noqa: PLR0915
         user_api_base = api_base
         dynamic_config[user_model]["api_base"] = api_base
     if api_version:
-        os.environ[
-            "AZURE_API_VERSION"
-        ] = api_version  # set this for azure - litellm can read this from the env
+        os.environ["AZURE_API_VERSION"] = (
+            api_version  # set this for azure - litellm can read this from the env
+        )
     if max_tokens:  # model-specific param
         dynamic_config[user_model]["max_tokens"] = max_tokens
     if temperature:  # model-specific param
@@ -3191,6 +3191,11 @@ class ProxyStartupEvent:
             )
             await proxy_logging_obj.slack_alerting_instance.send_fallback_stats_from_prometheus()
 
+        if litellm.prometheus_initialize_budget_metrics is True:
+            from litellm.integrations.prometheus import PrometheusLogger
+
+            PrometheusLogger.initialize_budget_metrics_cron_job(scheduler=scheduler)
+
         scheduler.start()
 
     @classmethod
@@ -7753,9 +7758,9 @@ async def get_config_list(
                         hasattr(sub_field_info, "description")
                         and sub_field_info.description is not None
                     ):
-                        nested_fields[
-                            idx
-                        ].field_description = sub_field_info.description
+                        nested_fields[idx].field_description = (
+                            sub_field_info.description
+                        )
                     idx += 1
 
             _stored_in_db = None
diff --git a/tests/litellm/integrations/test_prometheus.py b/tests/litellm/integrations/test_prometheus.py
new file mode 100644
index 0000000000..464477f019
--- /dev/null
+++ b/tests/litellm/integrations/test_prometheus.py
@@ -0,0 +1,44 @@
+"""
+Mock prometheus unit tests, these don't rely on LLM API calls
+"""
+
+import json
+import os
+import sys
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path
+
+from apscheduler.schedulers.asyncio import AsyncIOScheduler
+
+import litellm
+from litellm.constants import PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
+from litellm.integrations.prometheus import PrometheusLogger
+
+
+def test_initialize_budget_metrics_cron_job():
+    # Create a scheduler
+    scheduler = AsyncIOScheduler()
+
+    # Create and register a PrometheusLogger
+    prometheus_logger = PrometheusLogger()
+    litellm.callbacks = [prometheus_logger]
+
+    # Initialize the cron job
+    PrometheusLogger.initialize_budget_metrics_cron_job(scheduler)
+
+    # Verify that a job was added to the scheduler
+    jobs = scheduler.get_jobs()
+    assert len(jobs) == 1
+
+    # Verify job properties
+    job = jobs[0]
+    assert (
+        job.trigger.interval.total_seconds() / 60
+        == PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
+    )
+    assert job.func.__name__ == "initialize_remaining_budget_metrics"
diff --git a/tests/litellm/proxy/db/db_transaction_queue/test_pod_lock_manager.py b/tests/litellm/proxy/db/db_transaction_queue/test_pod_lock_manager.py
index 697d985dc9..e83fd75c3a 100644
--- a/tests/litellm/proxy/db/db_transaction_queue/test_pod_lock_manager.py
+++ b/tests/litellm/proxy/db/db_transaction_queue/test_pod_lock_manager.py
@@ -29,7 +29,7 @@ def mock_redis():
 
 @pytest.fixture
 def pod_lock_manager(mock_redis):
-    return PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
+    return PodLockManager(redis_cache=mock_redis)
 
 
 @pytest.mark.asyncio
@@ -40,12 +40,15 @@ async def test_acquire_lock_success(pod_lock_manager, mock_redis):
     # Mock successful acquisition (SET NX returns True)
     mock_redis.async_set_cache.return_value = True
 
-    result = await pod_lock_manager.acquire_lock()
+    result = await pod_lock_manager.acquire_lock(
+        cronjob_id="test_job",
+    )
     assert result == True
 
     # Verify set_cache was called with correct parameters
+    lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job")
     mock_redis.async_set_cache.assert_called_once_with(
-        pod_lock_manager.lock_key,
+        lock_key,
         pod_lock_manager.pod_id,
         nx=True,
         ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
@@ -62,13 +65,16 @@ async def test_acquire_lock_existing_active(pod_lock_manager, mock_redis):
     # Mock get_cache to return a different pod's ID
     mock_redis.async_get_cache.return_value = "different_pod_id"
 
-    result = await pod_lock_manager.acquire_lock()
+    result = await pod_lock_manager.acquire_lock(
+        cronjob_id="test_job",
+    )
     assert result == False
 
     # Verify set_cache was called
     mock_redis.async_set_cache.assert_called_once()
     # Verify get_cache was called to check existing lock
-    mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key)
+    lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job")
+    mock_redis.async_get_cache.assert_called_once_with(lock_key)
@@ -89,7 +95,9 @@ async def test_acquire_lock_expired(pod_lock_manager, mock_redis):
     # Then set succeeds on retry (simulating key expiring between checks)
     mock_redis.async_set_cache.side_effect = [False, True]
 
-    result = await pod_lock_manager.acquire_lock()
+    result = await pod_lock_manager.acquire_lock(
+        cronjob_id="test_job",
+    )
     assert result == False  # First attempt fails
 
     # Reset mock for a second attempt
@@ -97,7 +105,9 @@ async def test_acquire_lock_expired(pod_lock_manager, mock_redis):
     mock_redis.async_set_cache.side_effect = None
     mock_redis.async_set_cache.return_value = True
 
     # Try again (simulating the lock expired)
-    result = await pod_lock_manager.acquire_lock()
+    result = await pod_lock_manager.acquire_lock(
+        cronjob_id="test_job",
+    )
     assert result == True
 
     # Verify set_cache was called again
@@ -114,12 +124,15 @@ async def test_release_lock_success(pod_lock_manager, mock_redis):
     # Mock successful deletion
     mock_redis.async_delete_cache.return_value = 1
 
-    await pod_lock_manager.release_lock()
+    await pod_lock_manager.release_lock(
+        cronjob_id="test_job",
+    )
 
     # Verify get_cache was called
-    mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key)
+    lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job")
+    mock_redis.async_get_cache.assert_called_once_with(lock_key)
     # Verify delete_cache was called
-    mock_redis.async_delete_cache.assert_called_once_with(pod_lock_manager.lock_key)
+    mock_redis.async_delete_cache.assert_called_once_with(lock_key)
@@ -130,10 +143,13 @@ async def test_release_lock_different_pod(pod_lock_manager, mock_redis):
     # Mock get_cache to return a different pod's ID
     mock_redis.async_get_cache.return_value = "different_pod_id"
 
-    await pod_lock_manager.release_lock()
+    await pod_lock_manager.release_lock(
+        cronjob_id="test_job",
+    )
 
     # Verify get_cache was called
-    mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key)
+    lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job")
+    mock_redis.async_get_cache.assert_called_once_with(lock_key)
     # Verify delete_cache was NOT called
     mock_redis.async_delete_cache.assert_not_called()
@@ -146,10 +162,13 @@ async def test_release_lock_no_lock(pod_lock_manager, mock_redis):
     # Mock get_cache to return None (no lock)
     mock_redis.async_get_cache.return_value = None
 
-    await pod_lock_manager.release_lock()
+    await pod_lock_manager.release_lock(
+        cronjob_id="test_job",
+    )
 
     # Verify get_cache was called
-    mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key)
+    lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job")
+    mock_redis.async_get_cache.assert_called_once_with(lock_key)
     # Verify delete_cache was NOT called
     mock_redis.async_delete_cache.assert_not_called()
@@ -159,13 +178,20 @@ async def test_redis_none(monkeypatch):
     """
     Test behavior when redis_cache is None
     """
-    pod_lock_manager = PodLockManager(cronjob_id="test_job", redis_cache=None)
+    pod_lock_manager = PodLockManager(redis_cache=None)
 
     # Test acquire_lock with None redis_cache
-    assert await pod_lock_manager.acquire_lock() is None
+    assert (
+        await pod_lock_manager.acquire_lock(
+            cronjob_id="test_job",
+        )
+        is None
+    )
 
     # Test release_lock with None redis_cache (should not raise exception)
-    await pod_lock_manager.release_lock()
+    await pod_lock_manager.release_lock(
+        cronjob_id="test_job",
+    )
@@ -179,7 +205,9 @@ async def test_redis_error_handling(pod_lock_manager, mock_redis):
     mock_redis.async_delete_cache.side_effect = Exception("Redis error")
 
     # Test acquire_lock error handling
-    result = await pod_lock_manager.acquire_lock()
+    result = await pod_lock_manager.acquire_lock(
+        cronjob_id="test_job",
+    )
     assert result == False
 
     # Reset side effect for get_cache for the release test
@@ -187,7 +215,9 @@ async def test_redis_error_handling(pod_lock_manager, mock_redis):
     mock_redis.async_get_cache.side_effect = None
     mock_redis.async_get_cache.return_value = pod_lock_manager.pod_id
 
     # Test release_lock error handling (should not raise exception)
-    await pod_lock_manager.release_lock()
+    await pod_lock_manager.release_lock(
+        cronjob_id="test_job",
+    )
@@ -200,14 +230,18 @@ async def test_bytes_handling(pod_lock_manager, mock_redis):
     # Mock get_cache to return bytes
     mock_redis.async_get_cache.return_value = pod_lock_manager.pod_id.encode("utf-8")
 
-    result = await pod_lock_manager.acquire_lock()
+    result = await pod_lock_manager.acquire_lock(
+        cronjob_id="test_job",
+    )
     assert result == True
 
     # Reset for release test
     mock_redis.async_get_cache.return_value = pod_lock_manager.pod_id.encode("utf-8")
     mock_redis.async_delete_cache.return_value = 1
 
-    await pod_lock_manager.release_lock()
+    await pod_lock_manager.release_lock(
+        cronjob_id="test_job",
+    )
 
     mock_redis.async_delete_cache.assert_called_once()
@@ -217,15 +251,17 @@ async def test_concurrent_lock_acquisition_simulation():
     Simulate multiple pods trying to acquire the lock simultaneously
     """
     mock_redis = MockRedisCache()
-    pod1 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
-    pod2 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
-    pod3 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
+    pod1 = PodLockManager(redis_cache=mock_redis)
+    pod2 = PodLockManager(redis_cache=mock_redis)
+    pod3 = PodLockManager(redis_cache=mock_redis)
 
     # Simulate first pod getting the lock
     mock_redis.async_set_cache.return_value = True
 
     # First pod should get the lock
-    result1 = await pod1.acquire_lock()
+    result1 = await pod1.acquire_lock(
+        cronjob_id="test_job",
+    )
     assert result1 == True
 
     # Simulate other pods failing to get the lock
@@ -233,8 +269,12 @@ async def test_concurrent_lock_acquisition_simulation():
     mock_redis.async_set_cache.return_value = False
     mock_redis.async_get_cache.return_value = pod1.pod_id
 
     # Other pods should fail to acquire
-    result2 = await pod2.acquire_lock()
-    result3 = await pod3.acquire_lock()
+    result2 = await pod2.acquire_lock(
+        cronjob_id="test_job",
+    )
+    result3 = await pod3.acquire_lock(
+        cronjob_id="test_job",
+    )
 
     # Since other pods don't have the lock, they should get False
     assert result2 == False
@@ -246,14 +286,16 @@ async def test_lock_takeover_race_condition(mock_redis):
     """
     Test scenario where multiple pods try to take over an expired lock using Redis
     """
-    pod1 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
-    pod2 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
+    pod1 = PodLockManager(redis_cache=mock_redis)
+    pod2 = PodLockManager(redis_cache=mock_redis)
 
     # Simulate first pod's acquisition succeeding
     mock_redis.async_set_cache.return_value = True
 
     # First pod should successfully acquire
-    result1 = await pod1.acquire_lock()
+    result1 = await pod1.acquire_lock(
+        cronjob_id="test_job",
+    )
     assert result1 == True
 
     # Simulate race condition: second pod tries but fails
@@ -261,5 +303,7 @@ async def test_lock_takeover_race_condition(mock_redis):
     mock_redis.async_set_cache.return_value = False
     mock_redis.async_get_cache.return_value = pod1.pod_id
 
     # Second pod should fail to acquire
-    result2 = await pod2.acquire_lock()
+    result2 = await pod2.acquire_lock(
+        cronjob_id="test_job",
+    )
     assert result2 == False
diff --git a/tests/litellm/proxy/management_endpoints/test_ui_sso.py b/tests/litellm/proxy/management_endpoints/test_ui_sso.py
index 3ba1cb0d73..6683c2cc91 100644
--- a/tests/litellm/proxy/management_endpoints/test_ui_sso.py
+++ b/tests/litellm/proxy/management_endpoints/test_ui_sso.py
@@ -21,6 +21,7 @@ from litellm.proxy.management_endpoints.types import CustomOpenID
 from litellm.proxy.management_endpoints.ui_sso import (
     GoogleSSOHandler,
     MicrosoftSSOHandler,
+    SSOAuthenticationHandler,
 )
 from litellm.types.proxy.management_endpoints.ui_sso import (
     MicrosoftGraphAPIUserGroupDirectoryObject,
@@ -29,6 +30,37 @@ from litellm.types.proxy.management_endpoints.ui_sso import (
 )
 
 
+def test_microsoft_sso_handler_openid_from_response_user_principal_name():
+    # Arrange
+    # Create a mock response similar to what Microsoft SSO would return
+    mock_response = {
+        "userPrincipalName": "test@example.com",
+        "displayName": "Test User",
+        "id": "user123",
+        "givenName": "Test",
+        "surname": "User",
+        "some_other_field": "value",
+    }
+    expected_team_ids = ["team1", "team2"]
+    # Act
+    # Call the method being tested
+    result = MicrosoftSSOHandler.openid_from_response(
+        response=mock_response, team_ids=expected_team_ids
+    )
+
+    # Assert
+
+    # Check that the result is a CustomOpenID object with the expected values
+    assert isinstance(result, CustomOpenID)
+    assert result.email == "test@example.com"
+    assert result.display_name == "Test User"
+    assert result.provider == "microsoft"
+    assert result.id == "user123"
+    assert result.first_name == "Test"
+    assert result.last_name == "User"
+    assert result.team_ids == expected_team_ids
+
+
 def test_microsoft_sso_handler_openid_from_response():
     # Arrange
     # Create a mock response similar to what Microsoft SSO would return
diff --git a/tests/logging_callback_tests/test_prometheus_unit_tests.py b/tests/logging_callback_tests/test_prometheus_unit_tests.py
index ddfce710d7..0b58bc7aaf 100644
--- a/tests/logging_callback_tests/test_prometheus_unit_tests.py
+++ b/tests/logging_callback_tests/test_prometheus_unit_tests.py
@@ -39,7 +39,7 @@ import time
 
 
 @pytest.fixture
-def prometheus_logger():
+def prometheus_logger() -> PrometheusLogger:
     collectors = list(REGISTRY._collector_to_names.keys())
     for collector in collectors:
         REGISTRY.unregister(collector)
@@ -1212,24 +1212,6 @@ async def test_initialize_remaining_budget_metrics_exception_handling(
     prometheus_logger.litellm_remaining_api_key_budget_metric.assert_not_called()
 
 
-def test_initialize_prometheus_startup_metrics_no_loop(prometheus_logger):
-    """
-    Test that _initialize_prometheus_startup_metrics handles case when no event loop exists
-    """
-    # Mock asyncio.get_running_loop to raise RuntimeError
-    litellm.prometheus_initialize_budget_metrics = True
-    with patch(
-        "asyncio.get_running_loop", side_effect=RuntimeError("No running event loop")
-    ), patch("litellm._logging.verbose_logger.exception") as mock_logger:
-
-        # Call the function
-        prometheus_logger._initialize_prometheus_startup_metrics()
-
-        # Verify the error was logged
-        mock_logger.assert_called_once()
-        assert "No running event loop" in mock_logger.call_args[0][0]
-
-
 @pytest.mark.asyncio(scope="session")
 async def test_initialize_api_key_budget_metrics(prometheus_logger):
     """
diff --git a/tests/proxy_unit_tests/test_e2e_pod_lock_manager.py b/tests/proxy_unit_tests/test_e2e_pod_lock_manager.py
index 652b1838ac..061da8c186 100644
--- a/tests/proxy_unit_tests/test_e2e_pod_lock_manager.py
+++ b/tests/proxy_unit_tests/test_e2e_pod_lock_manager.py
@@ -141,10 +141,12 @@ async def setup_db_connection(prisma_client):
 async def test_pod_lock_acquisition_when_no_active_lock():
     """Test if a pod can acquire a lock when no lock is active"""
     cronjob_id = str(uuid.uuid4())
-    lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    lock_manager = PodLockManager(redis_cache=global_redis_cache)
 
     # Attempt to acquire lock
-    result = await lock_manager.acquire_lock()
+    result = await lock_manager.acquire_lock(
+        cronjob_id=cronjob_id,
+    )
 
     assert result == True, "Pod should be able to acquire lock when no lock exists"
@@ -161,13 +163,19 @@ async def test_pod_lock_acquisition_after_completion():
     """Test if a new pod can acquire lock after previous pod completes"""
     cronjob_id = str(uuid.uuid4())
     # First pod acquires and releases lock
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
-    await first_lock_manager.acquire_lock()
-    await first_lock_manager.release_lock()
+    first_lock_manager = PodLockManager(redis_cache=global_redis_cache)
+    await first_lock_manager.acquire_lock(
+        cronjob_id=cronjob_id,
+    )
+    await first_lock_manager.release_lock(
+        cronjob_id=cronjob_id,
+    )
 
     # Second pod attempts to acquire lock
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
-    result = await second_lock_manager.acquire_lock()
+    second_lock_manager = PodLockManager(redis_cache=global_redis_cache)
+    result = await second_lock_manager.acquire_lock(
+        cronjob_id=cronjob_id,
+    )
 
     assert result == True, "Second pod should acquire lock after first pod releases it"
@@ -182,15 +190,21 @@ async def test_pod_lock_acquisition_after_expiry():
     """Test if a new pod can acquire lock after previous pod's lock expires"""
     cronjob_id = str(uuid.uuid4())
     # First pod acquires lock
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
-    await first_lock_manager.acquire_lock()
+    first_lock_manager = PodLockManager(redis_cache=global_redis_cache)
+    await first_lock_manager.acquire_lock(
+        cronjob_id=cronjob_id,
+    )
 
     # release the lock from the first pod
-    await first_lock_manager.release_lock()
+    await first_lock_manager.release_lock(
+        cronjob_id=cronjob_id,
+    )
 
     # Second pod attempts to acquire lock
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
-    result = await second_lock_manager.acquire_lock()
+    second_lock_manager = PodLockManager(redis_cache=global_redis_cache)
+    result = await second_lock_manager.acquire_lock(
+        cronjob_id=cronjob_id,
+    )
 
     assert (
         result == True
@@ -206,11 +220,15 @@ async def test_pod_lock_release():
     """Test if a pod can successfully release its lock"""
     cronjob_id = str(uuid.uuid4())
-    lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    lock_manager = PodLockManager(redis_cache=global_redis_cache)
 
     # Acquire and then release lock
-    await lock_manager.acquire_lock()
-    await lock_manager.release_lock()
+    await lock_manager.acquire_lock(
+        cronjob_id=cronjob_id,
+    )
+    await lock_manager.release_lock(
+        cronjob_id=cronjob_id,
+    )
 
     # Verify in redis
     lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
@@ -224,15 +242,21 @@ async def test_concurrent_lock_acquisition():
     cronjob_id = str(uuid.uuid4())
 
     # Create multiple lock managers simulating different pods
-    lock_manager1 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
-    lock_manager2 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
-    lock_manager3 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    lock_manager1 = PodLockManager(redis_cache=global_redis_cache)
+    lock_manager2 = PodLockManager(redis_cache=global_redis_cache)
+    lock_manager3 = PodLockManager(redis_cache=global_redis_cache)
 
     # Try to acquire locks concurrently
     results = await asyncio.gather(
-        lock_manager1.acquire_lock(),
-        lock_manager2.acquire_lock(),
-        lock_manager3.acquire_lock(),
+        lock_manager1.acquire_lock(
+            cronjob_id=cronjob_id,
+        ),
+        lock_manager2.acquire_lock(
+            cronjob_id=cronjob_id,
+        ),
+        lock_manager3.acquire_lock(
+            cronjob_id=cronjob_id,
+        ),
     )
 
     # Only one should succeed
@@ -254,7 +278,7 @@ async def test_lock_acquisition_with_expired_ttl():
     """Test that a pod can acquire a lock when existing lock has expired TTL"""
     cronjob_id = str(uuid.uuid4())
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    first_lock_manager = PodLockManager(redis_cache=global_redis_cache)
 
     # First pod acquires lock with a very short TTL to simulate expiration
     short_ttl = 1  # 1 second
@@ -269,8 +293,10 @@ async def test_lock_acquisition_with_expired_ttl():
     await asyncio.sleep(short_ttl + 0.5)  # Wait slightly longer than the TTL
 
     # Second pod tries to acquire without explicit release
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
-    result = await second_lock_manager.acquire_lock()
+    second_lock_manager = PodLockManager(redis_cache=global_redis_cache)
+    result = await second_lock_manager.acquire_lock(
+        cronjob_id=cronjob_id,
+    )
 
     assert result == True, "Should acquire lock when existing lock has expired TTL"
@@ -286,7 +312,7 @@ async def test_release_expired_lock():
     cronjob_id = str(uuid.uuid4())
 
     # First pod acquires lock with a very short TTL
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    first_lock_manager = PodLockManager(redis_cache=global_redis_cache)
     short_ttl = 1  # 1 second
     lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
     await global_redis_cache.async_set_cache(
@@ -299,11 +325,15 @@ async def test_release_expired_lock():
     await asyncio.sleep(short_ttl + 0.5)  # Wait slightly longer than the TTL
 
     # Second pod acquires the lock
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
-    await second_lock_manager.acquire_lock()
+    second_lock_manager = PodLockManager(redis_cache=global_redis_cache)
+    await second_lock_manager.acquire_lock(
+        cronjob_id=cronjob_id,
+    )
 
     # First pod attempts to release its lock
-    await first_lock_manager.release_lock()
+    await first_lock_manager.release_lock(
+        cronjob_id=cronjob_id,
+    )
 
     # Verify that second pod's lock is still active
     lock_record = await global_redis_cache.async_get_cache(lock_key)