From 90d862b0416ba8de4931a4602aeb77a1f0af966f Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 10 Apr 2025 16:58:28 -0700 Subject: [PATCH 1/3] [Feat SSO] - Allow admins to set `default_team_params` to have default params for when litellm SSO creates default teams (#9895) * add default_team_params as a config.yaml setting * create_litellm_team_from_sso_group * test_default_team_params * test_create_team_without_default_params * docs default team settings --- docs/my-website/docs/proxy/self_serve.md | 34 +++++- litellm/__init__.py | 100 +++++++++-------- litellm/proxy/management_endpoints/ui_sso.py | 105 +++++++++++------- .../proxy/management_endpoints/test_ui_sso.py | 99 ++++++++++++++++- 4 files changed, 252 insertions(+), 86 deletions(-) diff --git a/docs/my-website/docs/proxy/self_serve.md b/docs/my-website/docs/proxy/self_serve.md index d630c8e7f3..2fc17d952e 100644 --- a/docs/my-website/docs/proxy/self_serve.md +++ b/docs/my-website/docs/proxy/self_serve.md @@ -207,9 +207,14 @@ This walks through setting up sso auto-add for **Microsoft Entra ID** Follow along this video for a walkthrough of how to set this up with Microsoft Entra ID - +
+
+ +**Next steps** + +1. [Set default params for new teams auto-created from SSO](#set-default-params-for-new-teams) ### Debugging SSO JWT fields @@ -279,6 +284,26 @@ This budget does not apply to keys created under non-default teams. [**Go Here**](./team_budgets.md) +### Set default params for new teams + +When you connect litellm to your SSO provider, litellm can auto-create teams. Use this to set the default `models`, `max_budget`, `budget_duration` for these auto-created teams. + +**How it works** + +1. When litellm fetches `groups` from your SSO provider, it will check if the corresponding group_id exists as a `team_id` in litellm. +2. If the team_id does not exist, litellm will auto-create a team with the default params you've set. +3. If the team_id already exist, litellm will not apply any settings on the team. + +**Usage** + +```yaml showLineNumbers title="Default Params for new teams" +litellm_settings: + default_team_params: # Default Params to apply when litellm auto creates a team from SSO IDP provider + max_budget: 100 # Optional[float], optional): $100 budget for the team + budget_duration: 30d # Optional[str], optional): 30 days budget_duration for the team + models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by the team +``` + ### Restrict Users from creating personal keys @@ -290,7 +315,7 @@ This will also prevent users from using their session tokens on the test keys ch ## **All Settings for Self Serve / SSO Flow** -```yaml +```yaml showLineNumbers title="All Settings for Self Serve / SSO Flow" litellm_settings: max_internal_user_budget: 10 # max budget for internal users internal_user_budget_duration: "1mo" # reset every month @@ -300,6 +325,11 @@ litellm_settings: max_budget: 100 # Optional[float], optional): $100 budget for a new SSO sign in user budget_duration: 30d # Optional[str], optional): 30 days budget_duration for a new SSO sign in user models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by a new SSO sign in user + + default_team_params: # Default Params to apply when litellm auto creates a team from SSO IDP provider + max_budget: 100 # Optional[float], optional): $100 budget for the team + budget_duration: 30d # Optional[str], optional): 30 days budget_duration for the team + models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by the team upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on diff --git a/litellm/__init__.py b/litellm/__init__.py index e061643398..a3b37da2b4 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -65,6 +65,7 @@ from litellm.proxy._types import ( KeyManagementSystem, KeyManagementSettings, LiteLLM_UpperboundKeyGenerateParams, + NewTeamRequest, ) from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders from litellm.integrations.custom_logger import CustomLogger @@ -126,19 +127,19 @@ prometheus_initialize_budget_metrics: Optional[bool] = False require_auth_for_metrics_endpoint: Optional[bool] = False argilla_batch_size: Optional[int] = None datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload -gcs_pub_sub_use_v1: Optional[ - bool -] = False # if you want to use v1 gcs pubsub logged payload +gcs_pub_sub_use_v1: Optional[bool] = ( + False # if you want to use v1 gcs pubsub logged payload +) argilla_transformation_object: Optional[Dict[str, Any]] = None -_async_input_callback: List[ - Union[str, Callable, CustomLogger] -] = [] # internal variable - async 
custom callbacks are routed here. -_async_success_callback: List[ - Union[str, Callable, CustomLogger] -] = [] # internal variable - async custom callbacks are routed here. -_async_failure_callback: List[ - Union[str, Callable, CustomLogger] -] = [] # internal variable - async custom callbacks are routed here. +_async_input_callback: List[Union[str, Callable, CustomLogger]] = ( + [] +) # internal variable - async custom callbacks are routed here. +_async_success_callback: List[Union[str, Callable, CustomLogger]] = ( + [] +) # internal variable - async custom callbacks are routed here. +_async_failure_callback: List[Union[str, Callable, CustomLogger]] = ( + [] +) # internal variable - async custom callbacks are routed here. pre_call_rules: List[Callable] = [] post_call_rules: List[Callable] = [] turn_off_message_logging: Optional[bool] = False @@ -146,18 +147,18 @@ log_raw_request_response: bool = False redact_messages_in_exceptions: Optional[bool] = False redact_user_api_key_info: Optional[bool] = False filter_invalid_headers: Optional[bool] = False -add_user_information_to_llm_headers: Optional[ - bool -] = None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers +add_user_information_to_llm_headers: Optional[bool] = ( + None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers +) store_audit_logs = False # Enterprise feature, allow users to see audit logs ### end of callbacks ############# -email: Optional[ - str -] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -token: Optional[ - str -] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +email: Optional[str] = ( + None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +) +token: Optional[str] = ( + None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +) telemetry = True max_tokens: int = DEFAULT_MAX_TOKENS # OpenAI Defaults drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False)) @@ -233,20 +234,24 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None enable_caching_on_provider_specific_optional_params: bool = ( False # feature-flag for caching on optional params - e.g. 
'top_k' ) -caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -cache: Optional[ - Cache -] = None # cache object <- use this - https://docs.litellm.ai/docs/caching +caching: bool = ( + False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +) +caching_with_models: bool = ( + False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +) +cache: Optional[Cache] = ( + None # cache object <- use this - https://docs.litellm.ai/docs/caching +) default_in_memory_ttl: Optional[float] = None default_redis_ttl: Optional[float] = None default_redis_batch_cache_expiry: Optional[float] = None model_alias_map: Dict[str, str] = {} model_group_alias_map: Dict[str, str] = {} max_budget: float = 0.0 # set the max budget across all providers -budget_duration: Optional[ - str -] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). +budget_duration: Optional[str] = ( + None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). +) default_soft_budget: float = ( DEFAULT_SOFT_BUDGET # by default all litellm proxy keys have a soft budget of 50.0 ) @@ -255,11 +260,15 @@ forward_traceparent_to_llm_provider: bool = False _current_cost = 0.0 # private variable, used if max budget is set error_logs: Dict = {} -add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt +add_function_to_prompt: bool = ( + False # if function calling not supported by api, append function call details to system prompt +) client_session: Optional[httpx.Client] = None aclient_session: Optional[httpx.AsyncClient] = None model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks' -model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" +model_cost_map_url: str = ( + "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" +) suppress_debug_info = False dynamodb_table_name: Optional[str] = None s3_callback_params: Optional[Dict] = None @@ -268,6 +277,7 @@ default_key_generate_params: Optional[Dict] = None upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None key_generation_settings: Optional[StandardKeyGenerationConfig] = None default_internal_user_params: Optional[Dict] = None +default_team_params: Optional[NewTeamRequest] = None default_team_settings: Optional[List] = None max_user_budget: Optional[float] = None default_max_internal_user_budget: Optional[float] = None @@ -281,7 +291,9 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None custom_prometheus_metadata_labels: List[str] = [] #### REQUEST PRIORITIZATION #### priority_reservation: Optional[Dict[str, float]] = None -force_ipv4: bool = False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6. +force_ipv4: bool = ( + False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6. 
+) module_level_aclient = AsyncHTTPHandler( timeout=request_timeout, client_alias="module level aclient" ) @@ -295,13 +307,13 @@ fallbacks: Optional[List] = None context_window_fallbacks: Optional[List] = None content_policy_fallbacks: Optional[List] = None allowed_fails: int = 3 -num_retries_per_request: Optional[ - int -] = None # for the request overall (incl. fallbacks + model retries) +num_retries_per_request: Optional[int] = ( + None # for the request overall (incl. fallbacks + model retries) +) ####### SECRET MANAGERS ##################### -secret_manager_client: Optional[ - Any -] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc. +secret_manager_client: Optional[Any] = ( + None # list of instantiated key management clients - e.g. azure kv, infisical, etc. +) _google_kms_resource_name: Optional[str] = None _key_management_system: Optional[KeyManagementSystem] = None _key_management_settings: KeyManagementSettings = KeyManagementSettings() @@ -1050,10 +1062,10 @@ from .types.llms.custom_llm import CustomLLMItem from .types.utils import GenericStreamingChunk custom_provider_map: List[CustomLLMItem] = [] -_custom_providers: List[ - str -] = [] # internal helper util, used to track names of custom providers -disable_hf_tokenizer_download: Optional[ - bool -] = None # disable huggingface tokenizer download. Defaults to openai clk100 +_custom_providers: List[str] = ( + [] +) # internal helper util, used to track names of custom providers +disable_hf_tokenizer_download: Optional[bool] = ( + None # disable huggingface tokenizer download. Defaults to openai clk100 +) global_disable_no_log_param: bool = False diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 0cd3600220..1e10aebedb 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -896,6 +896,68 @@ class SSOAuthenticationHandler: sso_teams = getattr(result, "team_ids", []) await add_missing_team_member(user_info=user_info, sso_teams=sso_teams) + @staticmethod + async def create_litellm_team_from_sso_group( + litellm_team_id: str, + litellm_team_name: Optional[str] = None, + ): + """ + Creates a Litellm Team from a SSO Group ID + + Your SSO provider might have groups that should be created on LiteLLM + + Use this helper to create a Litellm Team from a SSO Group ID + + Args: + litellm_team_id (str): The ID of the Litellm Team + litellm_team_name (Optional[str]): The name of the Litellm Team + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise ProxyException( + message="Prisma client not found. 
Set it in the proxy_server.py file", + type=ProxyErrorTypes.auth_error, + param="prisma_client", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + try: + team_obj = await prisma_client.db.litellm_teamtable.find_first( + where={"team_id": litellm_team_id} + ) + verbose_proxy_logger.debug(f"Team object: {team_obj}") + + # only create a new team if it doesn't exist + if team_obj: + verbose_proxy_logger.debug( + f"Team already exists: {litellm_team_id} - {litellm_team_name}" + ) + return + + team_request: NewTeamRequest = NewTeamRequest( + team_id=litellm_team_id, + team_alias=litellm_team_name, + ) + if litellm.default_team_params: + team_request = litellm.default_team_params.model_copy( + deep=True, + update={ + "team_id": litellm_team_id, + "team_alias": litellm_team_name, + }, + ) + await new_team( + data=team_request, + # params used for Audit Logging + http_request=Request(scope={"type": "http", "method": "POST"}), + user_api_key_dict=UserAPIKeyAuth( + token="", + key_alias=f"litellm.{MicrosoftSSOHandler.__name__}", + ), + ) + except Exception as e: + verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}") + class MicrosoftSSOHandler: """ @@ -1176,15 +1238,6 @@ class MicrosoftSSOHandler: When a user sets a `SERVICE_PRINCIPAL_ID` in the env, litellm will fetch groups under that service principal and create Litellm Teams from them """ - from litellm.proxy.proxy_server import prisma_client - - if prisma_client is None: - raise ProxyException( - message="Prisma client not found. Set it in the proxy_server.py file", - type=ProxyErrorTypes.auth_error, - param="prisma_client", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) verbose_proxy_logger.debug( f"Creating Litellm Teams from Service Principal Teams: {service_principal_teams}" ) @@ -1199,36 +1252,10 @@ class MicrosoftSSOHandler: ) continue - try: - verbose_proxy_logger.debug( - f"Creating Litellm Team: {litellm_team_id} - {litellm_team_name}" - ) - - team_obj = await prisma_client.db.litellm_teamtable.find_first( - where={"team_id": litellm_team_id} - ) - verbose_proxy_logger.debug(f"Team object: {team_obj}") - - # only create a new team if it doesn't exist - if team_obj: - verbose_proxy_logger.debug( - f"Team already exists: {litellm_team_id} - {litellm_team_name}" - ) - continue - await new_team( - data=NewTeamRequest( - team_id=litellm_team_id, - team_alias=litellm_team_name, - ), - # params used for Audit Logging - http_request=Request(scope={"type": "http", "method": "POST"}), - user_api_key_dict=UserAPIKeyAuth( - token="", - key_alias=f"litellm.{MicrosoftSSOHandler.__name__}", - ), - ) - except Exception as e: - verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}") + await SSOAuthenticationHandler.create_litellm_team_from_sso_group( + litellm_team_id=litellm_team_id, + litellm_team_name=litellm_team_name, + ) class GoogleSSOHandler: diff --git a/tests/litellm/proxy/management_endpoints/test_ui_sso.py b/tests/litellm/proxy/management_endpoints/test_ui_sso.py index 606f3833be..ff9700393f 100644 --- a/tests/litellm/proxy/management_endpoints/test_ui_sso.py +++ b/tests/litellm/proxy/management_endpoints/test_ui_sso.py @@ -2,8 +2,9 @@ import asyncio import json import os import sys +import uuid from typing import Optional, cast -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi import Request @@ -13,6 +14,8 @@ sys.path.insert( 0, os.path.abspath("../../../") ) # Adds the parent directory to the system path +import litellm +from 
litellm.proxy._types import NewTeamRequest from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.management_endpoints.types import CustomOpenID from litellm.proxy.management_endpoints.ui_sso import ( @@ -22,6 +25,7 @@ from litellm.proxy.management_endpoints.ui_sso import ( from litellm.types.proxy.management_endpoints.ui_sso import ( MicrosoftGraphAPIUserGroupDirectoryObject, MicrosoftGraphAPIUserGroupResponse, + MicrosoftServicePrincipalTeam, ) @@ -379,3 +383,96 @@ def test_get_group_ids_from_graph_api_response(): assert len(result) == 2 assert "group1" in result assert "group2" in result + + +@pytest.mark.asyncio +async def test_default_team_params(): + """ + When litellm.default_team_params is set, it should be used to create a new team + """ + # Arrange + litellm.default_team_params = NewTeamRequest( + max_budget=10, budget_duration="1d", models=["special-gpt-5"] + ) + + def mock_jsonify_team_object(db_data): + return db_data + + # Mock Prisma client + mock_prisma = MagicMock() + mock_prisma.db.litellm_teamtable.find_first = AsyncMock(return_value=None) + mock_prisma.db.litellm_teamtable.create = AsyncMock() + mock_prisma.get_data = AsyncMock(return_value=None) + mock_prisma.jsonify_team_object = MagicMock(side_effect=mock_jsonify_team_object) + + with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma): + # Act + team_id = str(uuid.uuid4()) + await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids( + service_principal_teams=[ + MicrosoftServicePrincipalTeam( + principalId=team_id, + principalDisplayName="Test Team", + ) + ] + ) + + # Assert + # Verify team was created with correct parameters + mock_prisma.db.litellm_teamtable.create.assert_called_once() + print( + "mock_prisma.db.litellm_teamtable.create.call_args", + mock_prisma.db.litellm_teamtable.create.call_args, + ) + create_call_args = mock_prisma.db.litellm_teamtable.create.call_args.kwargs[ + "data" + ] + assert create_call_args["team_id"] == team_id + assert create_call_args["team_alias"] == "Test Team" + assert create_call_args["max_budget"] == 10 + assert create_call_args["budget_duration"] == "1d" + assert create_call_args["models"] == ["special-gpt-5"] + + +@pytest.mark.asyncio +async def test_create_team_without_default_params(): + """ + Test team creation when litellm.default_team_params is None + Should create team with just the basic required fields + """ + # Arrange + litellm.default_team_params = None + + def mock_jsonify_team_object(db_data): + return db_data + + # Mock Prisma client + mock_prisma = MagicMock() + mock_prisma.db.litellm_teamtable.find_first = AsyncMock(return_value=None) + mock_prisma.db.litellm_teamtable.create = AsyncMock() + mock_prisma.get_data = AsyncMock(return_value=None) + mock_prisma.jsonify_team_object = MagicMock(side_effect=mock_jsonify_team_object) + + with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma): + # Act + team_id = str(uuid.uuid4()) + await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids( + service_principal_teams=[ + MicrosoftServicePrincipalTeam( + principalId=team_id, + principalDisplayName="Test Team", + ) + ] + ) + + # Assert + mock_prisma.db.litellm_teamtable.create.assert_called_once() + create_call_args = mock_prisma.db.litellm_teamtable.create.call_args.kwargs[ + "data" + ] + assert create_call_args["team_id"] == team_id + assert create_call_args["team_alias"] == "Test Team" + # Should not have any of the optional fields + assert "max_budget" not in create_call_args + assert 
"budget_duration" not in create_call_args + assert create_call_args["models"] == [] From 94a553dbb2e55e9776e875da07c5d322f9b26517 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 10 Apr 2025 16:59:14 -0700 Subject: [PATCH 2/3] [Feat] Emit Key, Team Budget metrics on a cron job schedule (#9528) * _initialize_remaining_budget_metrics * initialize_budget_metrics_cron_job * initialize_budget_metrics_cron_job * initialize_budget_metrics_cron_job * test_initialize_budget_metrics_cron_job * LITELLM_PROXY_ADMIN_NAME * fix code qa checks * test_initialize_budget_metrics_cron_job * test_initialize_budget_metrics_cron_job * pod lock manager allow dynamic cron job ID * fix pod lock manager * require cronjobid for PodLockManager * fix DB_SPEND_UPDATE_JOB_NAME acquire / release lock * add comment on prometheus logger * add debug statements for emitting key, team budget metrics * test_pod_lock_manager.py * test_initialize_budget_metrics_cron_job * initialize_budget_metrics_cron_job * initialize_remaining_budget_metrics * remove outdated test --- litellm/constants.py | 2 + litellm/integrations/prometheus.py | 107 +++++++++++++----- litellm/proxy/db/db_spend_update_writer.py | 10 +- .../db_transaction_queue/pod_lock_manager.py | 54 +++++---- litellm/proxy/proxy_config.yaml | 30 +++++ litellm/proxy/proxy_server.py | 29 +++-- tests/litellm/integrations/test_prometheus.py | 44 +++++++ .../test_pod_lock_manager.py | 106 ++++++++++++----- .../test_prometheus_unit_tests.py | 20 +--- .../test_e2e_pod_lock_manager.py | 86 +++++++++----- 10 files changed, 346 insertions(+), 142 deletions(-) create mode 100644 tests/litellm/integrations/test_prometheus.py diff --git a/litellm/constants.py b/litellm/constants.py index c8248f548a..12bfd17815 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -480,6 +480,7 @@ RESPONSE_FORMAT_TOOL_NAME = "json_tool_call" # default tool name used when conv ########################### Logging Callback Constants ########################### AZURE_STORAGE_MSFT_VERSION = "2019-07-07" +PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES = 5 MCP_TOOL_NAME_PREFIX = "mcp_tool" ########################### LiteLLM Proxy Specific Constants ########################### @@ -514,6 +515,7 @@ LITELLM_PROXY_ADMIN_NAME = "default_user_id" ########################### DB CRON JOB NAMES ########################### DB_SPEND_UPDATE_JOB_NAME = "db_spend_update_job" +PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME = "prometheus_emit_budget_metrics_job" DEFAULT_CRON_JOB_LOCK_TTL_SECONDS = 60 # 1 minute PROXY_BUDGET_RESCHEDULER_MIN_TIME = 597 PROXY_BUDGET_RESCHEDULER_MAX_TIME = 605 diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index 6fba69d005..f61321e53d 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -1,10 +1,19 @@ # used for /metrics endpoint on LiteLLM Proxy #### What this does #### # On success, log events to Prometheus -import asyncio import sys from datetime import datetime, timedelta -from typing import Any, Awaitable, Callable, List, Literal, Optional, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + List, + Literal, + Optional, + Tuple, + cast, +) import litellm from litellm._logging import print_verbose, verbose_logger @@ -14,6 +23,11 @@ from litellm.types.integrations.prometheus import * from litellm.types.utils import StandardLoggingPayload from litellm.utils import get_end_user_id_for_cost_tracking +if TYPE_CHECKING: + from apscheduler.schedulers.asyncio import 
AsyncIOScheduler +else: + AsyncIOScheduler = Any + class PrometheusLogger(CustomLogger): # Class variables or attributes @@ -359,8 +373,6 @@ class PrometheusLogger(CustomLogger): label_name="litellm_requests_metric" ), ) - self._initialize_prometheus_startup_metrics() - except Exception as e: print_verbose(f"Got exception on init prometheus client {str(e)}") raise e @@ -988,9 +1000,9 @@ class PrometheusLogger(CustomLogger): ): try: verbose_logger.debug("setting remaining tokens requests metric") - standard_logging_payload: Optional[ - StandardLoggingPayload - ] = request_kwargs.get("standard_logging_object") + standard_logging_payload: Optional[StandardLoggingPayload] = ( + request_kwargs.get("standard_logging_object") + ) if standard_logging_payload is None: return @@ -1337,24 +1349,6 @@ class PrometheusLogger(CustomLogger): return max_budget - spend - def _initialize_prometheus_startup_metrics(self): - """ - Initialize prometheus startup metrics - - Helper to create tasks for initializing metrics that are required on startup - eg. remaining budget metrics - """ - if litellm.prometheus_initialize_budget_metrics is not True: - verbose_logger.debug("Prometheus: skipping budget metrics initialization") - return - - try: - if asyncio.get_running_loop(): - asyncio.create_task(self._initialize_remaining_budget_metrics()) - except RuntimeError as e: # no running event loop - verbose_logger.exception( - f"No running event loop - skipping budget metrics initialization: {str(e)}" - ) - async def _initialize_budget_metrics( self, data_fetch_function: Callable[..., Awaitable[Tuple[List[Any], Optional[int]]]], @@ -1475,12 +1469,41 @@ class PrometheusLogger(CustomLogger): data_type="keys", ) - async def _initialize_remaining_budget_metrics(self): + async def initialize_remaining_budget_metrics(self): """ - Initialize remaining budget metrics for all teams to avoid metric discrepancies. + Handler for initializing remaining budget metrics for all teams to avoid metric discrepancies. Runs when prometheus logger starts up. + + - If redis cache is available, we use the pod lock manager to acquire a lock and initialize the metrics. + - Ensures only one pod emits the metrics at a time. + - If redis cache is not available, we initialize the metrics directly. """ + from litellm.constants import PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME + from litellm.proxy.proxy_server import proxy_logging_obj + + pod_lock_manager = proxy_logging_obj.db_spend_update_writer.pod_lock_manager + + # if using redis, ensure only one pod emits the metrics at a time + if pod_lock_manager and pod_lock_manager.redis_cache: + if await pod_lock_manager.acquire_lock( + cronjob_id=PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME + ): + try: + await self._initialize_remaining_budget_metrics() + finally: + await pod_lock_manager.release_lock( + cronjob_id=PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME + ) + else: + # if not using redis, initialize the metrics directly + await self._initialize_remaining_budget_metrics() + + async def _initialize_remaining_budget_metrics(self): + """ + Helper to initialize remaining budget metrics for all teams and API keys. 
+ """ + verbose_logger.debug("Emitting key, team budget metrics....") await self._initialize_team_budget_metrics() await self._initialize_api_key_budget_metrics() @@ -1737,6 +1760,36 @@ class PrometheusLogger(CustomLogger): return (end_time - start_time).total_seconds() return None + @staticmethod + def initialize_budget_metrics_cron_job(scheduler: AsyncIOScheduler): + """ + Initialize budget metrics as a cron job. This job runs every `PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES` minutes. + + It emits the current remaining budget metrics for all Keys and Teams. + """ + from litellm.constants import PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES + from litellm.integrations.custom_logger import CustomLogger + from litellm.integrations.prometheus import PrometheusLogger + + prometheus_loggers: List[CustomLogger] = ( + litellm.logging_callback_manager.get_custom_loggers_for_type( + callback_type=PrometheusLogger + ) + ) + # we need to get the initialized prometheus logger instance(s) and call logger.initialize_remaining_budget_metrics() on them + verbose_logger.debug("found %s prometheus loggers", len(prometheus_loggers)) + if len(prometheus_loggers) > 0: + prometheus_logger = cast(PrometheusLogger, prometheus_loggers[0]) + verbose_logger.debug( + "Initializing remaining budget metrics as a cron job executing every %s minutes" + % PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES + ) + scheduler.add_job( + prometheus_logger.initialize_remaining_budget_metrics, + "interval", + minutes=PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES, + ) + @staticmethod def _mount_metrics_endpoint(premium_user: bool): """ diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py index b32dc5c691..12ae51822c 100644 --- a/litellm/proxy/db/db_spend_update_writer.py +++ b/litellm/proxy/db/db_spend_update_writer.py @@ -53,7 +53,7 @@ class DBSpendUpdateWriter: ): self.redis_cache = redis_cache self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache) - self.pod_lock_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME) + self.pod_lock_manager = PodLockManager() self.spend_update_queue = SpendUpdateQueue() self.daily_spend_update_queue = DailySpendUpdateQueue() @@ -383,7 +383,9 @@ class DBSpendUpdateWriter: ) # Only commit from redis to db if this pod is the leader - if await self.pod_lock_manager.acquire_lock(): + if await self.pod_lock_manager.acquire_lock( + cronjob_id=DB_SPEND_UPDATE_JOB_NAME, + ): verbose_proxy_logger.debug("acquired lock for spend updates") try: @@ -411,7 +413,9 @@ class DBSpendUpdateWriter: except Exception as e: verbose_proxy_logger.error(f"Error committing spend updates: {e}") finally: - await self.pod_lock_manager.release_lock() + await self.pod_lock_manager.release_lock( + cronjob_id=DB_SPEND_UPDATE_JOB_NAME, + ) async def _commit_spend_updates_to_db_without_redis_buffer( self, diff --git a/litellm/proxy/db/db_transaction_queue/pod_lock_manager.py b/litellm/proxy/db/db_transaction_queue/pod_lock_manager.py index cb4a43a802..be3be64546 100644 --- a/litellm/proxy/db/db_transaction_queue/pod_lock_manager.py +++ b/litellm/proxy/db/db_transaction_queue/pod_lock_manager.py @@ -21,18 +21,18 @@ class PodLockManager: Ensures that only one pod can run a cron job at a time. 
""" - def __init__(self, cronjob_id: str, redis_cache: Optional[RedisCache] = None): + def __init__(self, redis_cache: Optional[RedisCache] = None): self.pod_id = str(uuid.uuid4()) - self.cronjob_id = cronjob_id self.redis_cache = redis_cache - # Define a unique key for this cronjob lock in Redis. - self.lock_key = PodLockManager.get_redis_lock_key(cronjob_id) @staticmethod def get_redis_lock_key(cronjob_id: str) -> str: return f"cronjob_lock:{cronjob_id}" - async def acquire_lock(self) -> Optional[bool]: + async def acquire_lock( + self, + cronjob_id: str, + ) -> Optional[bool]: """ Attempt to acquire the lock for a specific cron job using Redis. Uses the SET command with NX and EX options to ensure atomicity. @@ -44,12 +44,13 @@ class PodLockManager: verbose_proxy_logger.debug( "Pod %s attempting to acquire Redis lock for cronjob_id=%s", self.pod_id, - self.cronjob_id, + cronjob_id, ) # Try to set the lock key with the pod_id as its value, only if it doesn't exist (NX) # and with an expiration (EX) to avoid deadlocks. + lock_key = PodLockManager.get_redis_lock_key(cronjob_id) acquired = await self.redis_cache.async_set_cache( - self.lock_key, + lock_key, self.pod_id, nx=True, ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS, @@ -58,13 +59,13 @@ class PodLockManager: verbose_proxy_logger.info( "Pod %s successfully acquired Redis lock for cronjob_id=%s", self.pod_id, - self.cronjob_id, + cronjob_id, ) return True else: # Check if the current pod already holds the lock - current_value = await self.redis_cache.async_get_cache(self.lock_key) + current_value = await self.redis_cache.async_get_cache(lock_key) if current_value is not None: if isinstance(current_value, bytes): current_value = current_value.decode("utf-8") @@ -72,18 +73,21 @@ class PodLockManager: verbose_proxy_logger.info( "Pod %s already holds the Redis lock for cronjob_id=%s", self.pod_id, - self.cronjob_id, + cronjob_id, ) - self._emit_acquired_lock_event(self.cronjob_id, self.pod_id) + self._emit_acquired_lock_event(cronjob_id, self.pod_id) return True return False except Exception as e: verbose_proxy_logger.error( - f"Error acquiring Redis lock for {self.cronjob_id}: {e}" + f"Error acquiring Redis lock for {cronjob_id}: {e}" ) return False - async def release_lock(self): + async def release_lock( + self, + cronjob_id: str, + ): """ Release the lock if the current pod holds it. Uses get and delete commands to ensure that only the owner can release the lock. 
@@ -92,46 +96,52 @@ class PodLockManager: verbose_proxy_logger.debug("redis_cache is None, skipping release_lock") return try: + cronjob_id = cronjob_id verbose_proxy_logger.debug( "Pod %s attempting to release Redis lock for cronjob_id=%s", self.pod_id, - self.cronjob_id, + cronjob_id, ) - current_value = await self.redis_cache.async_get_cache(self.lock_key) + lock_key = PodLockManager.get_redis_lock_key(cronjob_id) + + current_value = await self.redis_cache.async_get_cache(lock_key) if current_value is not None: if isinstance(current_value, bytes): current_value = current_value.decode("utf-8") if current_value == self.pod_id: - result = await self.redis_cache.async_delete_cache(self.lock_key) + result = await self.redis_cache.async_delete_cache(lock_key) if result == 1: verbose_proxy_logger.info( "Pod %s successfully released Redis lock for cronjob_id=%s", self.pod_id, - self.cronjob_id, + cronjob_id, + ) + self._emit_released_lock_event( + cronjob_id=cronjob_id, + pod_id=self.pod_id, ) - self._emit_released_lock_event(self.cronjob_id, self.pod_id) else: verbose_proxy_logger.debug( "Pod %s failed to release Redis lock for cronjob_id=%s", self.pod_id, - self.cronjob_id, + cronjob_id, ) else: verbose_proxy_logger.debug( "Pod %s cannot release Redis lock for cronjob_id=%s because it is held by pod %s", self.pod_id, - self.cronjob_id, + cronjob_id, current_value, ) else: verbose_proxy_logger.debug( "Pod %s attempted to release Redis lock for cronjob_id=%s, but no lock was found", self.pod_id, - self.cronjob_id, + cronjob_id, ) except Exception as e: verbose_proxy_logger.error( - f"Error releasing Redis lock for {self.cronjob_id}: {e}" + f"Error releasing Redis lock for {cronjob_id}: {e}" ) @staticmethod diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 23de923db7..847ca7ce56 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -10,6 +10,36 @@ model_list: api_key: fake-key litellm_settings: + prometheus_initialize_budget_metrics: true + callbacks: ["prometheus"] + +mcp_tools: + - name: "get_current_time" + description: "Get the current time" + input_schema: { + "type": "object", + "properties": { + "format": { + "type": "string", + "description": "The format of the time to return", + "enum": ["short"] + } + } + } + handler: "mcp_tools.get_current_time" + - name: "get_current_date" + description: "Get the current date" + input_schema: { + "type": "object", + "properties": { + "format": { + "type": "string", + "description": "The format of the date to return", + "enum": ["short"] + } + } + } + handler: "mcp_tools.get_current_date" default_team_settings: - team_id: test_dev success_callback: ["langfuse", "s3"] diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index ddfb7118d7..84b515f405 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -803,9 +803,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter( dual_cache=user_api_key_cache ) litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter) -redis_usage_cache: Optional[ - RedisCache -] = None # redis cache used for tracking spend, tpm/rpm limits +redis_usage_cache: Optional[RedisCache] = ( + None # redis cache used for tracking spend, tpm/rpm limits +) user_custom_auth = None user_custom_key_generate = None user_custom_sso = None @@ -1131,9 +1131,9 @@ async def update_cache( # noqa: PLR0915 _id = "team_id:{}".format(team_id) try: # Fetch the existing cost for the given user - 
existing_spend_obj: Optional[ - LiteLLM_TeamTable - ] = await user_api_key_cache.async_get_cache(key=_id) + existing_spend_obj: Optional[LiteLLM_TeamTable] = ( + await user_api_key_cache.async_get_cache(key=_id) + ) if existing_spend_obj is None: # do nothing if team not in api key cache return @@ -2812,9 +2812,9 @@ async def initialize( # noqa: PLR0915 user_api_base = api_base dynamic_config[user_model]["api_base"] = api_base if api_version: - os.environ[ - "AZURE_API_VERSION" - ] = api_version # set this for azure - litellm can read this from the env + os.environ["AZURE_API_VERSION"] = ( + api_version # set this for azure - litellm can read this from the env + ) if max_tokens: # model-specific param dynamic_config[user_model]["max_tokens"] = max_tokens if temperature: # model-specific param @@ -3191,6 +3191,11 @@ class ProxyStartupEvent: ) await proxy_logging_obj.slack_alerting_instance.send_fallback_stats_from_prometheus() + if litellm.prometheus_initialize_budget_metrics is True: + from litellm.integrations.prometheus import PrometheusLogger + + PrometheusLogger.initialize_budget_metrics_cron_job(scheduler=scheduler) + scheduler.start() @classmethod @@ -7753,9 +7758,9 @@ async def get_config_list( hasattr(sub_field_info, "description") and sub_field_info.description is not None ): - nested_fields[ - idx - ].field_description = sub_field_info.description + nested_fields[idx].field_description = ( + sub_field_info.description + ) idx += 1 _stored_in_db = None diff --git a/tests/litellm/integrations/test_prometheus.py b/tests/litellm/integrations/test_prometheus.py new file mode 100644 index 0000000000..464477f019 --- /dev/null +++ b/tests/litellm/integrations/test_prometheus.py @@ -0,0 +1,44 @@ +""" +Mock prometheus unit tests, these don't rely on LLM API calls +""" + +import json +import os +import sys + +import pytest +from fastapi.testclient import TestClient + +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path + +from apscheduler.schedulers.asyncio import AsyncIOScheduler + +import litellm +from litellm.constants import PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES +from litellm.integrations.prometheus import PrometheusLogger + + +def test_initialize_budget_metrics_cron_job(): + # Create a scheduler + scheduler = AsyncIOScheduler() + + # Create and register a PrometheusLogger + prometheus_logger = PrometheusLogger() + litellm.callbacks = [prometheus_logger] + + # Initialize the cron job + PrometheusLogger.initialize_budget_metrics_cron_job(scheduler) + + # Verify that a job was added to the scheduler + jobs = scheduler.get_jobs() + assert len(jobs) == 1 + + # Verify job properties + job = jobs[0] + assert ( + job.trigger.interval.total_seconds() / 60 + == PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES + ) + assert job.func.__name__ == "initialize_remaining_budget_metrics" diff --git a/tests/litellm/proxy/db/db_transaction_queue/test_pod_lock_manager.py b/tests/litellm/proxy/db/db_transaction_queue/test_pod_lock_manager.py index 697d985dc9..e83fd75c3a 100644 --- a/tests/litellm/proxy/db/db_transaction_queue/test_pod_lock_manager.py +++ b/tests/litellm/proxy/db/db_transaction_queue/test_pod_lock_manager.py @@ -29,7 +29,7 @@ def mock_redis(): @pytest.fixture def pod_lock_manager(mock_redis): - return PodLockManager(cronjob_id="test_job", redis_cache=mock_redis) + return PodLockManager(redis_cache=mock_redis) @pytest.mark.asyncio @@ -40,12 +40,15 @@ async def test_acquire_lock_success(pod_lock_manager, mock_redis): # Mock 
successful acquisition (SET NX returns True) mock_redis.async_set_cache.return_value = True - result = await pod_lock_manager.acquire_lock() + result = await pod_lock_manager.acquire_lock( + cronjob_id="test_job", + ) assert result == True # Verify set_cache was called with correct parameters + lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job") mock_redis.async_set_cache.assert_called_once_with( - pod_lock_manager.lock_key, + lock_key, pod_lock_manager.pod_id, nx=True, ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS, @@ -62,13 +65,16 @@ async def test_acquire_lock_existing_active(pod_lock_manager, mock_redis): # Mock get_cache to return a different pod's ID mock_redis.async_get_cache.return_value = "different_pod_id" - result = await pod_lock_manager.acquire_lock() + result = await pod_lock_manager.acquire_lock( + cronjob_id="test_job", + ) assert result == False # Verify set_cache was called mock_redis.async_set_cache.assert_called_once() # Verify get_cache was called to check existing lock - mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key) + lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job") + mock_redis.async_get_cache.assert_called_once_with(lock_key) @pytest.mark.asyncio @@ -89,7 +95,9 @@ async def test_acquire_lock_expired(pod_lock_manager, mock_redis): # Then set succeeds on retry (simulating key expiring between checks) mock_redis.async_set_cache.side_effect = [False, True] - result = await pod_lock_manager.acquire_lock() + result = await pod_lock_manager.acquire_lock( + cronjob_id="test_job", + ) assert result == False # First attempt fails # Reset mock for a second attempt @@ -97,7 +105,9 @@ async def test_acquire_lock_expired(pod_lock_manager, mock_redis): mock_redis.async_set_cache.return_value = True # Try again (simulating the lock expired) - result = await pod_lock_manager.acquire_lock() + result = await pod_lock_manager.acquire_lock( + cronjob_id="test_job", + ) assert result == True # Verify set_cache was called again @@ -114,12 +124,15 @@ async def test_release_lock_success(pod_lock_manager, mock_redis): # Mock successful deletion mock_redis.async_delete_cache.return_value = 1 - await pod_lock_manager.release_lock() + await pod_lock_manager.release_lock( + cronjob_id="test_job", + ) # Verify get_cache was called - mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key) + lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job") + mock_redis.async_get_cache.assert_called_once_with(lock_key) # Verify delete_cache was called - mock_redis.async_delete_cache.assert_called_once_with(pod_lock_manager.lock_key) + mock_redis.async_delete_cache.assert_called_once_with(lock_key) @pytest.mark.asyncio @@ -130,10 +143,13 @@ async def test_release_lock_different_pod(pod_lock_manager, mock_redis): # Mock get_cache to return a different pod's ID mock_redis.async_get_cache.return_value = "different_pod_id" - await pod_lock_manager.release_lock() + await pod_lock_manager.release_lock( + cronjob_id="test_job", + ) # Verify get_cache was called - mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key) + lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job") + mock_redis.async_get_cache.assert_called_once_with(lock_key) # Verify delete_cache was NOT called mock_redis.async_delete_cache.assert_not_called() @@ -146,10 +162,13 @@ async def test_release_lock_no_lock(pod_lock_manager, mock_redis): # Mock get_cache to return None (no lock) 
mock_redis.async_get_cache.return_value = None - await pod_lock_manager.release_lock() + await pod_lock_manager.release_lock( + cronjob_id="test_job", + ) # Verify get_cache was called - mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key) + lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job") + mock_redis.async_get_cache.assert_called_once_with(lock_key) # Verify delete_cache was NOT called mock_redis.async_delete_cache.assert_not_called() @@ -159,13 +178,20 @@ async def test_redis_none(monkeypatch): """ Test behavior when redis_cache is None """ - pod_lock_manager = PodLockManager(cronjob_id="test_job", redis_cache=None) + pod_lock_manager = PodLockManager(redis_cache=None) # Test acquire_lock with None redis_cache - assert await pod_lock_manager.acquire_lock() is None + assert ( + await pod_lock_manager.acquire_lock( + cronjob_id="test_job", + ) + is None + ) # Test release_lock with None redis_cache (should not raise exception) - await pod_lock_manager.release_lock() + await pod_lock_manager.release_lock( + cronjob_id="test_job", + ) @pytest.mark.asyncio @@ -179,7 +205,9 @@ async def test_redis_error_handling(pod_lock_manager, mock_redis): mock_redis.async_delete_cache.side_effect = Exception("Redis error") # Test acquire_lock error handling - result = await pod_lock_manager.acquire_lock() + result = await pod_lock_manager.acquire_lock( + cronjob_id="test_job", + ) assert result == False # Reset side effect for get_cache for the release test @@ -187,7 +215,9 @@ async def test_redis_error_handling(pod_lock_manager, mock_redis): mock_redis.async_get_cache.return_value = pod_lock_manager.pod_id # Test release_lock error handling (should not raise exception) - await pod_lock_manager.release_lock() + await pod_lock_manager.release_lock( + cronjob_id="test_job", + ) @pytest.mark.asyncio @@ -200,14 +230,18 @@ async def test_bytes_handling(pod_lock_manager, mock_redis): # Mock get_cache to return bytes mock_redis.async_get_cache.return_value = pod_lock_manager.pod_id.encode("utf-8") - result = await pod_lock_manager.acquire_lock() + result = await pod_lock_manager.acquire_lock( + cronjob_id="test_job", + ) assert result == True # Reset for release test mock_redis.async_get_cache.return_value = pod_lock_manager.pod_id.encode("utf-8") mock_redis.async_delete_cache.return_value = 1 - await pod_lock_manager.release_lock() + await pod_lock_manager.release_lock( + cronjob_id="test_job", + ) mock_redis.async_delete_cache.assert_called_once() @@ -217,15 +251,17 @@ async def test_concurrent_lock_acquisition_simulation(): Simulate multiple pods trying to acquire the lock simultaneously """ mock_redis = MockRedisCache() - pod1 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis) - pod2 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis) - pod3 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis) + pod1 = PodLockManager(redis_cache=mock_redis) + pod2 = PodLockManager(redis_cache=mock_redis) + pod3 = PodLockManager(redis_cache=mock_redis) # Simulate first pod getting the lock mock_redis.async_set_cache.return_value = True # First pod should get the lock - result1 = await pod1.acquire_lock() + result1 = await pod1.acquire_lock( + cronjob_id="test_job", + ) assert result1 == True # Simulate other pods failing to get the lock @@ -233,8 +269,12 @@ async def test_concurrent_lock_acquisition_simulation(): mock_redis.async_get_cache.return_value = pod1.pod_id # Other pods should fail to acquire - result2 = await pod2.acquire_lock() - 
result3 = await pod3.acquire_lock() + result2 = await pod2.acquire_lock( + cronjob_id="test_job", + ) + result3 = await pod3.acquire_lock( + cronjob_id="test_job", + ) # Since other pods don't have the lock, they should get False assert result2 == False @@ -246,14 +286,16 @@ async def test_lock_takeover_race_condition(mock_redis): """ Test scenario where multiple pods try to take over an expired lock using Redis """ - pod1 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis) - pod2 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis) + pod1 = PodLockManager(redis_cache=mock_redis) + pod2 = PodLockManager(redis_cache=mock_redis) # Simulate first pod's acquisition succeeding mock_redis.async_set_cache.return_value = True # First pod should successfully acquire - result1 = await pod1.acquire_lock() + result1 = await pod1.acquire_lock( + cronjob_id="test_job", + ) assert result1 == True # Simulate race condition: second pod tries but fails @@ -261,5 +303,7 @@ async def test_lock_takeover_race_condition(mock_redis): mock_redis.async_get_cache.return_value = pod1.pod_id # Second pod should fail to acquire - result2 = await pod2.acquire_lock() + result2 = await pod2.acquire_lock( + cronjob_id="test_job", + ) assert result2 == False diff --git a/tests/logging_callback_tests/test_prometheus_unit_tests.py b/tests/logging_callback_tests/test_prometheus_unit_tests.py index ddfce710d7..0b58bc7aaf 100644 --- a/tests/logging_callback_tests/test_prometheus_unit_tests.py +++ b/tests/logging_callback_tests/test_prometheus_unit_tests.py @@ -39,7 +39,7 @@ import time @pytest.fixture -def prometheus_logger(): +def prometheus_logger() -> PrometheusLogger: collectors = list(REGISTRY._collector_to_names.keys()) for collector in collectors: REGISTRY.unregister(collector) @@ -1212,24 +1212,6 @@ async def test_initialize_remaining_budget_metrics_exception_handling( prometheus_logger.litellm_remaining_api_key_budget_metric.assert_not_called() -def test_initialize_prometheus_startup_metrics_no_loop(prometheus_logger): - """ - Test that _initialize_prometheus_startup_metrics handles case when no event loop exists - """ - # Mock asyncio.get_running_loop to raise RuntimeError - litellm.prometheus_initialize_budget_metrics = True - with patch( - "asyncio.get_running_loop", side_effect=RuntimeError("No running event loop") - ), patch("litellm._logging.verbose_logger.exception") as mock_logger: - - # Call the function - prometheus_logger._initialize_prometheus_startup_metrics() - - # Verify the error was logged - mock_logger.assert_called_once() - assert "No running event loop" in mock_logger.call_args[0][0] - - @pytest.mark.asyncio(scope="session") async def test_initialize_api_key_budget_metrics(prometheus_logger): """ diff --git a/tests/proxy_unit_tests/test_e2e_pod_lock_manager.py b/tests/proxy_unit_tests/test_e2e_pod_lock_manager.py index 652b1838ac..061da8c186 100644 --- a/tests/proxy_unit_tests/test_e2e_pod_lock_manager.py +++ b/tests/proxy_unit_tests/test_e2e_pod_lock_manager.py @@ -141,10 +141,12 @@ async def setup_db_connection(prisma_client): async def test_pod_lock_acquisition_when_no_active_lock(): """Test if a pod can acquire a lock when no lock is active""" cronjob_id = str(uuid.uuid4()) - lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache) + lock_manager = PodLockManager(redis_cache=global_redis_cache) # Attempt to acquire lock - result = await lock_manager.acquire_lock() + result = await lock_manager.acquire_lock( + cronjob_id=cronjob_id, + ) assert 
result == True, "Pod should be able to acquire lock when no lock exists" @@ -161,13 +163,19 @@ async def test_pod_lock_acquisition_after_completion(): """Test if a new pod can acquire lock after previous pod completes""" cronjob_id = str(uuid.uuid4()) # First pod acquires and releases lock - first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache) - await first_lock_manager.acquire_lock() - await first_lock_manager.release_lock() + first_lock_manager = PodLockManager(redis_cache=global_redis_cache) + await first_lock_manager.acquire_lock( + cronjob_id=cronjob_id, + ) + await first_lock_manager.release_lock( + cronjob_id=cronjob_id, + ) # Second pod attempts to acquire lock - second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache) - result = await second_lock_manager.acquire_lock() + second_lock_manager = PodLockManager(redis_cache=global_redis_cache) + result = await second_lock_manager.acquire_lock( + cronjob_id=cronjob_id, + ) assert result == True, "Second pod should acquire lock after first pod releases it" @@ -182,15 +190,21 @@ async def test_pod_lock_acquisition_after_expiry(): """Test if a new pod can acquire lock after previous pod's lock expires""" cronjob_id = str(uuid.uuid4()) # First pod acquires lock - first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache) - await first_lock_manager.acquire_lock() + first_lock_manager = PodLockManager(redis_cache=global_redis_cache) + await first_lock_manager.acquire_lock( + cronjob_id=cronjob_id, + ) # release the lock from the first pod - await first_lock_manager.release_lock() + await first_lock_manager.release_lock( + cronjob_id=cronjob_id, + ) # Second pod attempts to acquire lock - second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache) - result = await second_lock_manager.acquire_lock() + second_lock_manager = PodLockManager(redis_cache=global_redis_cache) + result = await second_lock_manager.acquire_lock( + cronjob_id=cronjob_id, + ) assert ( result == True @@ -206,11 +220,15 @@ async def test_pod_lock_acquisition_after_expiry(): async def test_pod_lock_release(): """Test if a pod can successfully release its lock""" cronjob_id = str(uuid.uuid4()) - lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache) + lock_manager = PodLockManager(redis_cache=global_redis_cache) # Acquire and then release lock - await lock_manager.acquire_lock() - await lock_manager.release_lock() + await lock_manager.acquire_lock( + cronjob_id=cronjob_id, + ) + await lock_manager.release_lock( + cronjob_id=cronjob_id, + ) # Verify in redis lock_key = PodLockManager.get_redis_lock_key(cronjob_id) @@ -224,15 +242,21 @@ async def test_concurrent_lock_acquisition(): cronjob_id = str(uuid.uuid4()) # Create multiple lock managers simulating different pods - lock_manager1 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache) - lock_manager2 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache) - lock_manager3 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache) + lock_manager1 = PodLockManager(redis_cache=global_redis_cache) + lock_manager2 = PodLockManager(redis_cache=global_redis_cache) + lock_manager3 = PodLockManager(redis_cache=global_redis_cache) # Try to acquire locks concurrently results = await asyncio.gather( - lock_manager1.acquire_lock(), - lock_manager2.acquire_lock(), - lock_manager3.acquire_lock(), + lock_manager1.acquire_lock( + 
cronjob_id=cronjob_id, + ), + lock_manager2.acquire_lock( + cronjob_id=cronjob_id, + ), + lock_manager3.acquire_lock( + cronjob_id=cronjob_id, + ), ) # Only one should succeed @@ -254,7 +278,7 @@ async def test_concurrent_lock_acquisition(): async def test_lock_acquisition_with_expired_ttl(): """Test that a pod can acquire a lock when existing lock has expired TTL""" cronjob_id = str(uuid.uuid4()) - first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache) + first_lock_manager = PodLockManager(redis_cache=global_redis_cache) # First pod acquires lock with a very short TTL to simulate expiration short_ttl = 1 # 1 second @@ -269,8 +293,10 @@ async def test_lock_acquisition_with_expired_ttl(): await asyncio.sleep(short_ttl + 0.5) # Wait slightly longer than the TTL # Second pod tries to acquire without explicit release - second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache) - result = await second_lock_manager.acquire_lock() + second_lock_manager = PodLockManager(redis_cache=global_redis_cache) + result = await second_lock_manager.acquire_lock( + cronjob_id=cronjob_id, + ) assert result == True, "Should acquire lock when existing lock has expired TTL" @@ -286,7 +312,7 @@ async def test_release_expired_lock(): cronjob_id = str(uuid.uuid4()) # First pod acquires lock with a very short TTL - first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache) + first_lock_manager = PodLockManager(redis_cache=global_redis_cache) short_ttl = 1 # 1 second lock_key = PodLockManager.get_redis_lock_key(cronjob_id) await global_redis_cache.async_set_cache( @@ -299,11 +325,15 @@ async def test_release_expired_lock(): await asyncio.sleep(short_ttl + 0.5) # Wait slightly longer than the TTL # Second pod acquires the lock - second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache) - await second_lock_manager.acquire_lock() + second_lock_manager = PodLockManager(redis_cache=global_redis_cache) + await second_lock_manager.acquire_lock( + cronjob_id=cronjob_id, + ) # First pod attempts to release its lock - await first_lock_manager.release_lock() + await first_lock_manager.release_lock( + cronjob_id=cronjob_id, + ) # Verify that second pod's lock is still active lock_record = await global_redis_cache.async_get_cache(lock_key) From 72a12e91c4eaf62532e0257427407a7f622db8b7 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 10 Apr 2025 17:40:58 -0700 Subject: [PATCH 3/3] [Bug Fix MSFT SSO] Use correct field for user email when using MSFT SSO (#9886) * fix openid_from_response * test_microsoft_sso_handler_openid_from_response_user_principal_name * test upsert_sso_user --- litellm/proxy/management_endpoints/ui_sso.py | 26 +----- .../proxy/management_endpoints/test_ui_sso.py | 80 +++++++++++++++++++ 2 files changed, 83 insertions(+), 23 deletions(-) diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 1e10aebedb..0365336e73 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -468,9 +468,6 @@ async def auth_callback(request: Request): # noqa: PLR0915 result=result, user_info=user_info, user_email=user_email, - user_id_models=user_id_models, - max_internal_user_budget=max_internal_user_budget, - internal_user_budget_duration=internal_user_budget_duration, user_defined_values=user_defined_values, prisma_client=prisma_client, ) @@ -831,37 +828,20 @@ class 
SSOAuthenticationHandler: result: Optional[Union[CustomOpenID, OpenID, dict]], user_info: Optional[Union[NewUserResponse, LiteLLM_UserTable]], user_email: Optional[str], - user_id_models: List[str], - max_internal_user_budget: Optional[float], - internal_user_budget_duration: Optional[str], user_defined_values: Optional[SSOUserDefinedValues], prisma_client: PrismaClient, ): """ Connects the SSO Users to the User Table in LiteLLM DB - - If user on LiteLLM DB, update the user_id with the SSO user_id + - If user on LiteLLM DB, update the user_email with the SSO user_email - If user not on LiteLLM DB, insert the user into LiteLLM DB """ try: if user_info is not None: user_id = user_info.user_id - user_defined_values = SSOUserDefinedValues( - models=getattr(user_info, "models", user_id_models), - user_id=user_info.user_id or "", - user_email=getattr(user_info, "user_email", user_email), - user_role=getattr(user_info, "user_role", None), - max_budget=getattr( - user_info, "max_budget", max_internal_user_budget - ), - budget_duration=getattr( - user_info, "budget_duration", internal_user_budget_duration - ), - ) - - # update id await prisma_client.db.litellm_usertable.update_many( - where={"user_email": user_email}, data={"user_id": user_id} # type: ignore + where={"user_id": user_id}, data={"user_email": user_email} ) else: verbose_proxy_logger.info( @@ -1045,7 +1025,7 @@ class MicrosoftSSOHandler: response = response or {} verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}") openid_response = CustomOpenID( - email=response.get("mail"), + email=response.get("userPrincipalName") or response.get("mail"), display_name=response.get("displayName"), provider="microsoft", id=response.get("id"), diff --git a/tests/litellm/proxy/management_endpoints/test_ui_sso.py b/tests/litellm/proxy/management_endpoints/test_ui_sso.py index ff9700393f..09e337bf84 100644 --- a/tests/litellm/proxy/management_endpoints/test_ui_sso.py +++ b/tests/litellm/proxy/management_endpoints/test_ui_sso.py @@ -21,6 +21,7 @@ from litellm.proxy.management_endpoints.types import CustomOpenID from litellm.proxy.management_endpoints.ui_sso import ( GoogleSSOHandler, MicrosoftSSOHandler, + SSOAuthenticationHandler, ) from litellm.types.proxy.management_endpoints.ui_sso import ( MicrosoftGraphAPIUserGroupDirectoryObject, @@ -29,6 +30,37 @@ from litellm.types.proxy.management_endpoints.ui_sso import ( ) +def test_microsoft_sso_handler_openid_from_response_user_principal_name(): + # Arrange + # Create a mock response similar to what Microsoft SSO would return + mock_response = { + "userPrincipalName": "test@example.com", + "displayName": "Test User", + "id": "user123", + "givenName": "Test", + "surname": "User", + "some_other_field": "value", + } + expected_team_ids = ["team1", "team2"] + # Act + # Call the method being tested + result = MicrosoftSSOHandler.openid_from_response( + response=mock_response, team_ids=expected_team_ids + ) + + # Assert + + # Check that the result is a CustomOpenID object with the expected values + assert isinstance(result, CustomOpenID) + assert result.email == "test@example.com" + assert result.display_name == "Test User" + assert result.provider == "microsoft" + assert result.id == "user123" + assert result.first_name == "Test" + assert result.last_name == "User" + assert result.team_ids == expected_team_ids + + def test_microsoft_sso_handler_openid_from_response(): # Arrange # Create a mock response similar to what Microsoft SSO would return @@ -386,6 +418,54 @@ def 
test_get_group_ids_from_graph_api_response(): @pytest.mark.asyncio +async def test_upsert_sso_user_existing_user(): + """ + If a user_id is already in the LiteLLM DB and the user signed in with SSO. Ensure that the user_id is updated with the SSO user_email + + SSO Test + """ + # Arrange + mock_prisma = MagicMock() + mock_prisma.db = MagicMock() + mock_prisma.db.litellm_usertable = MagicMock() + mock_prisma.db.litellm_usertable.update_many = AsyncMock() + + # Create a mock existing user + mock_user = MagicMock() + mock_user.user_id = "existing_user_123" + mock_user.user_email = "old_email@example.com" + + # Create mock SSO response + mock_sso_response = CustomOpenID( + email="new_email@example.com", + display_name="Test User", + provider="microsoft", + id="existing_user_123", + first_name="Test", + last_name="User", + team_ids=[], + ) + + # Create mock user defined values + mock_user_defined_values = MagicMock() + + # Act + result = await SSOAuthenticationHandler.upsert_sso_user( + result=mock_sso_response, + user_info=mock_user, + user_email="new_email@example.com", + user_defined_values=mock_user_defined_values, + prisma_client=mock_prisma, + ) + + # Assert + mock_prisma.db.litellm_usertable.update_many.assert_called_once_with( + where={"user_id": "existing_user_123"}, + data={"user_email": "new_email@example.com"}, + ) + assert result == mock_user + + async def test_default_team_params(): """ When litellm.default_team_params is set, it should be used to create a new team
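
As a closing note on how the two SSO changes compose: when `default_team_params` is set, `create_litellm_team_from_sso_group` copies those defaults and only overrides the identity fields for the group being created. A minimal, hypothetical sketch of that copy-then-override step — using a stand-in model rather than the real `NewTeamRequest`, limited to the fields shown in the docs above — looks like this:

```python
# Illustrative-only sketch of the model_copy(deep=True, update=...) pattern used
# when litellm auto-creates a team from an SSO group. TeamParamsSketch is a
# hypothetical stand-in for NewTeamRequest.
from typing import List, Optional

from pydantic import BaseModel


class TeamParamsSketch(BaseModel):
    team_id: Optional[str] = None
    team_alias: Optional[str] = None
    max_budget: Optional[float] = None
    budget_duration: Optional[str] = None
    models: List[str] = []


# equivalent of litellm.default_team_params loaded from config.yaml
default_team_params = TeamParamsSketch(
    max_budget=100, budget_duration="30d", models=["gpt-3.5-turbo"]
)

# per-SSO-group override: defaults are kept, identity fields are replaced
team_request = default_team_params.model_copy(
    deep=True,
    update={"team_id": "entra-group-123", "team_alias": "Engineering"},
)

assert team_request.max_budget == 100
assert team_request.team_id == "entra-group-123"
```
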