diff --git a/docs/my-website/docs/proxy/self_serve.md b/docs/my-website/docs/proxy/self_serve.md
index d630c8e7f3..2fc17d952e 100644
--- a/docs/my-website/docs/proxy/self_serve.md
+++ b/docs/my-website/docs/proxy/self_serve.md
@@ -207,9 +207,14 @@ This walks through setting up sso auto-add for **Microsoft Entra ID**
Follow along this video for a walkthrough of how to set this up with Microsoft Entra ID
-
+
+
+
+**Next steps**
+
+1. [Set default params for new teams auto-created from SSO](#set-default-params-for-new-teams)
### Debugging SSO JWT fields
@@ -279,6 +284,26 @@ This budget does not apply to keys created under non-default teams.
[**Go Here**](./team_budgets.md)
+### Set default params for new teams
+
+When you connect litellm to your SSO provider, litellm can auto-create teams from your SSO provider's groups. Use this setting to define the default `models`, `max_budget`, and `budget_duration` for these auto-created teams.
+
+**How it works**
+
+1. When litellm fetches `groups` from your SSO provider, it will check if the corresponding group_id exists as a `team_id` in litellm.
+2. If the team_id does not exist, litellm will auto-create a team with the default params you've set.
+3. If the team_id already exists, litellm will leave the existing team unchanged and will not apply the default params (see the sketch after the config example below).
+
+**Usage**
+
+```yaml showLineNumbers title="Default Params for new teams"
+litellm_settings:
+  default_team_params:             # Default params to apply when litellm auto-creates a team from an SSO IdP provider
+ max_budget: 100 # Optional[float], optional): $100 budget for the team
+ budget_duration: 30d # Optional[str], optional): 30 days budget_duration for the team
+ models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by the team
+```
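
A minimal sketch of the "How it works" decision flow above, assuming simplified stand-in names rather than litellm internals — the actual behavior lives in `SSOAuthenticationHandler.create_litellm_team_from_sso_group` later in this diff:

```python
# Illustrative only: names here are simplified stand-ins, not litellm internals.
from typing import Optional, Set


def resolve_sso_team(
    group_id: str,
    group_name: str,
    existing_team_ids: Set[str],
    default_team_params: Optional[dict] = None,
) -> Optional[dict]:
    # steps 1 + 3: the SSO group already exists as a team_id -> leave it untouched
    if group_id in existing_team_ids:
        return None
    # step 2: auto-create the team, applying the configured defaults if any
    team = {"team_id": group_id, "team_alias": group_name}
    if default_team_params:
        team.update(default_team_params)
    return team
```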
+
### Restrict Users from creating personal keys
@@ -290,7 +315,7 @@ This will also prevent users from using their session tokens on the test keys ch
## **All Settings for Self Serve / SSO Flow**
-```yaml
+```yaml showLineNumbers title="All Settings for Self Serve / SSO Flow"
litellm_settings:
max_internal_user_budget: 10 # max budget for internal users
internal_user_budget_duration: "1mo" # reset every month
@@ -300,6 +325,11 @@ litellm_settings:
max_budget: 100 # Optional[float], optional): $100 budget for a new SSO sign in user
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for a new SSO sign in user
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by a new SSO sign in user
+
+  default_team_params:             # Default params to apply when litellm auto-creates a team from an SSO IdP provider
+ max_budget: 100 # Optional[float], optional): $100 budget for the team
+ budget_duration: 30d # Optional[str], optional): 30 days budget_duration for the team
+ models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by the team
upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on
diff --git a/litellm/__init__.py b/litellm/__init__.py
index e061643398..a3b37da2b4 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -65,6 +65,7 @@ from litellm.proxy._types import (
KeyManagementSystem,
KeyManagementSettings,
LiteLLM_UpperboundKeyGenerateParams,
+ NewTeamRequest,
)
from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
from litellm.integrations.custom_logger import CustomLogger
@@ -126,19 +127,19 @@ prometheus_initialize_budget_metrics: Optional[bool] = False
require_auth_for_metrics_endpoint: Optional[bool] = False
argilla_batch_size: Optional[int] = None
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
-gcs_pub_sub_use_v1: Optional[
- bool
-] = False # if you want to use v1 gcs pubsub logged payload
+gcs_pub_sub_use_v1: Optional[bool] = (
+ False # if you want to use v1 gcs pubsub logged payload
+)
argilla_transformation_object: Optional[Dict[str, Any]] = None
-_async_input_callback: List[
- Union[str, Callable, CustomLogger]
-] = [] # internal variable - async custom callbacks are routed here.
-_async_success_callback: List[
- Union[str, Callable, CustomLogger]
-] = [] # internal variable - async custom callbacks are routed here.
-_async_failure_callback: List[
- Union[str, Callable, CustomLogger]
-] = [] # internal variable - async custom callbacks are routed here.
+_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
+ []
+) # internal variable - async custom callbacks are routed here.
+_async_success_callback: List[Union[str, Callable, CustomLogger]] = (
+ []
+) # internal variable - async custom callbacks are routed here.
+_async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
+ []
+) # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
turn_off_message_logging: Optional[bool] = False
@@ -146,18 +147,18 @@ log_raw_request_response: bool = False
redact_messages_in_exceptions: Optional[bool] = False
redact_user_api_key_info: Optional[bool] = False
filter_invalid_headers: Optional[bool] = False
-add_user_information_to_llm_headers: Optional[
- bool
-] = None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
+add_user_information_to_llm_headers: Optional[bool] = (
+ None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
+)
store_audit_logs = False # Enterprise feature, allow users to see audit logs
### end of callbacks #############
-email: Optional[
- str
-] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-token: Optional[
- str
-] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+email: Optional[str] = (
+ None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+)
+token: Optional[str] = (
+ None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+)
telemetry = True
max_tokens: int = DEFAULT_MAX_TOKENS # OpenAI Defaults
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
@@ -233,20 +234,24 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
enable_caching_on_provider_specific_optional_params: bool = (
False # feature-flag for caching on optional params - e.g. 'top_k'
)
-caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-cache: Optional[
- Cache
-] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
+caching: bool = (
+ False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+)
+caching_with_models: bool = (
+    False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+)
+cache: Optional[Cache] = (
+ None # cache object <- use this - https://docs.litellm.ai/docs/caching
+)
default_in_memory_ttl: Optional[float] = None
default_redis_ttl: Optional[float] = None
default_redis_batch_cache_expiry: Optional[float] = None
model_alias_map: Dict[str, str] = {}
model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0 # set the max budget across all providers
-budget_duration: Optional[
- str
-] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
+budget_duration: Optional[str] = (
+ None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
+)
default_soft_budget: float = (
DEFAULT_SOFT_BUDGET # by default all litellm proxy keys have a soft budget of 50.0
)
@@ -255,11 +260,15 @@ forward_traceparent_to_llm_provider: bool = False
_current_cost = 0.0 # private variable, used if max budget is set
error_logs: Dict = {}
-add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
+add_function_to_prompt: bool = (
+ False # if function calling not supported by api, append function call details to system prompt
+)
client_session: Optional[httpx.Client] = None
aclient_session: Optional[httpx.AsyncClient] = None
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
-model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
+model_cost_map_url: str = (
+ "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
+)
suppress_debug_info = False
dynamodb_table_name: Optional[str] = None
s3_callback_params: Optional[Dict] = None
@@ -268,6 +277,7 @@ default_key_generate_params: Optional[Dict] = None
upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
key_generation_settings: Optional[StandardKeyGenerationConfig] = None
default_internal_user_params: Optional[Dict] = None
+default_team_params: Optional[NewTeamRequest] = None
default_team_settings: Optional[List] = None
max_user_budget: Optional[float] = None
default_max_internal_user_budget: Optional[float] = None
@@ -281,7 +291,9 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
custom_prometheus_metadata_labels: List[str] = []
#### REQUEST PRIORITIZATION ####
priority_reservation: Optional[Dict[str, float]] = None
-force_ipv4: bool = False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
+force_ipv4: bool = (
+ False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
+)
module_level_aclient = AsyncHTTPHandler(
timeout=request_timeout, client_alias="module level aclient"
)
@@ -295,13 +307,13 @@ fallbacks: Optional[List] = None
context_window_fallbacks: Optional[List] = None
content_policy_fallbacks: Optional[List] = None
allowed_fails: int = 3
-num_retries_per_request: Optional[
- int
-] = None # for the request overall (incl. fallbacks + model retries)
+num_retries_per_request: Optional[int] = (
+ None # for the request overall (incl. fallbacks + model retries)
+)
####### SECRET MANAGERS #####################
-secret_manager_client: Optional[
- Any
-] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
+secret_manager_client: Optional[Any] = (
+ None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
+)
_google_kms_resource_name: Optional[str] = None
_key_management_system: Optional[KeyManagementSystem] = None
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
@@ -1050,10 +1062,10 @@ from .types.llms.custom_llm import CustomLLMItem
from .types.utils import GenericStreamingChunk
custom_provider_map: List[CustomLLMItem] = []
-_custom_providers: List[
- str
-] = [] # internal helper util, used to track names of custom providers
-disable_hf_tokenizer_download: Optional[
- bool
-] = None # disable huggingface tokenizer download. Defaults to openai clk100
+_custom_providers: List[str] = (
+ []
+) # internal helper util, used to track names of custom providers
+disable_hf_tokenizer_download: Optional[bool] = (
+ None # disable huggingface tokenizer download. Defaults to openai clk100
+)
global_disable_no_log_param: bool = False
diff --git a/litellm/constants.py b/litellm/constants.py
index c8248f548a..12bfd17815 100644
--- a/litellm/constants.py
+++ b/litellm/constants.py
@@ -480,6 +480,7 @@ RESPONSE_FORMAT_TOOL_NAME = "json_tool_call" # default tool name used when conv
########################### Logging Callback Constants ###########################
AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
+PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES = 5
MCP_TOOL_NAME_PREFIX = "mcp_tool"
########################### LiteLLM Proxy Specific Constants ###########################
@@ -514,6 +515,7 @@ LITELLM_PROXY_ADMIN_NAME = "default_user_id"
########################### DB CRON JOB NAMES ###########################
DB_SPEND_UPDATE_JOB_NAME = "db_spend_update_job"
+PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME = "prometheus_emit_budget_metrics_job"
DEFAULT_CRON_JOB_LOCK_TTL_SECONDS = 60 # 1 minute
PROXY_BUDGET_RESCHEDULER_MIN_TIME = 597
PROXY_BUDGET_RESCHEDULER_MAX_TIME = 605
diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py
index 6fba69d005..f61321e53d 100644
--- a/litellm/integrations/prometheus.py
+++ b/litellm/integrations/prometheus.py
@@ -1,10 +1,19 @@
# used for /metrics endpoint on LiteLLM Proxy
#### What this does ####
# On success, log events to Prometheus
-import asyncio
import sys
from datetime import datetime, timedelta
-from typing import Any, Awaitable, Callable, List, Literal, Optional, Tuple, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+ cast,
+)
import litellm
from litellm._logging import print_verbose, verbose_logger
@@ -14,6 +23,11 @@ from litellm.types.integrations.prometheus import *
from litellm.types.utils import StandardLoggingPayload
from litellm.utils import get_end_user_id_for_cost_tracking
+if TYPE_CHECKING:
+ from apscheduler.schedulers.asyncio import AsyncIOScheduler
+else:
+ AsyncIOScheduler = Any
+
class PrometheusLogger(CustomLogger):
# Class variables or attributes
@@ -359,8 +373,6 @@ class PrometheusLogger(CustomLogger):
label_name="litellm_requests_metric"
),
)
- self._initialize_prometheus_startup_metrics()
-
except Exception as e:
print_verbose(f"Got exception on init prometheus client {str(e)}")
raise e
@@ -988,9 +1000,9 @@ class PrometheusLogger(CustomLogger):
):
try:
verbose_logger.debug("setting remaining tokens requests metric")
- standard_logging_payload: Optional[
- StandardLoggingPayload
- ] = request_kwargs.get("standard_logging_object")
+ standard_logging_payload: Optional[StandardLoggingPayload] = (
+ request_kwargs.get("standard_logging_object")
+ )
if standard_logging_payload is None:
return
@@ -1337,24 +1349,6 @@ class PrometheusLogger(CustomLogger):
return max_budget - spend
- def _initialize_prometheus_startup_metrics(self):
- """
- Initialize prometheus startup metrics
-
- Helper to create tasks for initializing metrics that are required on startup - eg. remaining budget metrics
- """
- if litellm.prometheus_initialize_budget_metrics is not True:
- verbose_logger.debug("Prometheus: skipping budget metrics initialization")
- return
-
- try:
- if asyncio.get_running_loop():
- asyncio.create_task(self._initialize_remaining_budget_metrics())
- except RuntimeError as e: # no running event loop
- verbose_logger.exception(
- f"No running event loop - skipping budget metrics initialization: {str(e)}"
- )
-
async def _initialize_budget_metrics(
self,
data_fetch_function: Callable[..., Awaitable[Tuple[List[Any], Optional[int]]]],
@@ -1475,12 +1469,41 @@ class PrometheusLogger(CustomLogger):
data_type="keys",
)
- async def _initialize_remaining_budget_metrics(self):
+ async def initialize_remaining_budget_metrics(self):
"""
- Initialize remaining budget metrics for all teams to avoid metric discrepancies.
+        Handler for initializing remaining budget metrics for all teams and API keys to avoid metric discrepancies.
Runs when prometheus logger starts up.
+
+ - If redis cache is available, we use the pod lock manager to acquire a lock and initialize the metrics.
+ - Ensures only one pod emits the metrics at a time.
+ - If redis cache is not available, we initialize the metrics directly.
"""
+ from litellm.constants import PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME
+ from litellm.proxy.proxy_server import proxy_logging_obj
+
+ pod_lock_manager = proxy_logging_obj.db_spend_update_writer.pod_lock_manager
+
+ # if using redis, ensure only one pod emits the metrics at a time
+ if pod_lock_manager and pod_lock_manager.redis_cache:
+ if await pod_lock_manager.acquire_lock(
+ cronjob_id=PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME
+ ):
+ try:
+ await self._initialize_remaining_budget_metrics()
+ finally:
+ await pod_lock_manager.release_lock(
+ cronjob_id=PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME
+ )
+ else:
+ # if not using redis, initialize the metrics directly
+ await self._initialize_remaining_budget_metrics()
+
+ async def _initialize_remaining_budget_metrics(self):
+ """
+ Helper to initialize remaining budget metrics for all teams and API keys.
+ """
+        verbose_logger.debug("Emitting key and team budget metrics...")
await self._initialize_team_budget_metrics()
await self._initialize_api_key_budget_metrics()
@@ -1737,6 +1760,36 @@ class PrometheusLogger(CustomLogger):
return (end_time - start_time).total_seconds()
return None
+ @staticmethod
+ def initialize_budget_metrics_cron_job(scheduler: AsyncIOScheduler):
+ """
+ Initialize budget metrics as a cron job. This job runs every `PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES` minutes.
+
+ It emits the current remaining budget metrics for all Keys and Teams.
+ """
+ from litellm.constants import PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
+ from litellm.integrations.custom_logger import CustomLogger
+ from litellm.integrations.prometheus import PrometheusLogger
+
+ prometheus_loggers: List[CustomLogger] = (
+ litellm.logging_callback_manager.get_custom_loggers_for_type(
+ callback_type=PrometheusLogger
+ )
+ )
+ # we need to get the initialized prometheus logger instance(s) and call logger.initialize_remaining_budget_metrics() on them
+ verbose_logger.debug("found %s prometheus loggers", len(prometheus_loggers))
+ if len(prometheus_loggers) > 0:
+ prometheus_logger = cast(PrometheusLogger, prometheus_loggers[0])
+ verbose_logger.debug(
+ "Initializing remaining budget metrics as a cron job executing every %s minutes"
+ % PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
+ )
+ scheduler.add_job(
+ prometheus_logger.initialize_remaining_budget_metrics,
+ "interval",
+ minutes=PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES,
+ )
+
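
For context, a hedged sketch of how this cron job gets wired up at startup — it mirrors the `proxy_server.py` hunk later in this diff; the surrounding scheduler setup is illustrative and assumes the prometheus callback has already been registered:

```python
import asyncio

from apscheduler.schedulers.asyncio import AsyncIOScheduler

import litellm
from litellm.integrations.prometheus import PrometheusLogger


async def _startup_example():
    scheduler = AsyncIOScheduler()
    if litellm.prometheus_initialize_budget_metrics is True:
        # adds an interval job that re-emits key/team remaining-budget metrics
        PrometheusLogger.initialize_budget_metrics_cron_job(scheduler=scheduler)
    scheduler.start()
    await asyncio.sleep(0)  # a real server keeps the loop alive; kept short here
```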
@staticmethod
def _mount_metrics_endpoint(premium_user: bool):
"""
diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py
index b32dc5c691..12ae51822c 100644
--- a/litellm/proxy/db/db_spend_update_writer.py
+++ b/litellm/proxy/db/db_spend_update_writer.py
@@ -53,7 +53,7 @@ class DBSpendUpdateWriter:
):
self.redis_cache = redis_cache
self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache)
- self.pod_lock_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME)
+ self.pod_lock_manager = PodLockManager()
self.spend_update_queue = SpendUpdateQueue()
self.daily_spend_update_queue = DailySpendUpdateQueue()
@@ -383,7 +383,9 @@ class DBSpendUpdateWriter:
)
# Only commit from redis to db if this pod is the leader
- if await self.pod_lock_manager.acquire_lock():
+ if await self.pod_lock_manager.acquire_lock(
+ cronjob_id=DB_SPEND_UPDATE_JOB_NAME,
+ ):
verbose_proxy_logger.debug("acquired lock for spend updates")
try:
@@ -411,7 +413,9 @@ class DBSpendUpdateWriter:
except Exception as e:
verbose_proxy_logger.error(f"Error committing spend updates: {e}")
finally:
- await self.pod_lock_manager.release_lock()
+ await self.pod_lock_manager.release_lock(
+ cronjob_id=DB_SPEND_UPDATE_JOB_NAME,
+ )
async def _commit_spend_updates_to_db_without_redis_buffer(
self,
diff --git a/litellm/proxy/db/db_transaction_queue/pod_lock_manager.py b/litellm/proxy/db/db_transaction_queue/pod_lock_manager.py
index cb4a43a802..be3be64546 100644
--- a/litellm/proxy/db/db_transaction_queue/pod_lock_manager.py
+++ b/litellm/proxy/db/db_transaction_queue/pod_lock_manager.py
@@ -21,18 +21,18 @@ class PodLockManager:
Ensures that only one pod can run a cron job at a time.
"""
- def __init__(self, cronjob_id: str, redis_cache: Optional[RedisCache] = None):
+ def __init__(self, redis_cache: Optional[RedisCache] = None):
self.pod_id = str(uuid.uuid4())
- self.cronjob_id = cronjob_id
self.redis_cache = redis_cache
- # Define a unique key for this cronjob lock in Redis.
- self.lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
@staticmethod
def get_redis_lock_key(cronjob_id: str) -> str:
return f"cronjob_lock:{cronjob_id}"
- async def acquire_lock(self) -> Optional[bool]:
+ async def acquire_lock(
+ self,
+ cronjob_id: str,
+ ) -> Optional[bool]:
"""
Attempt to acquire the lock for a specific cron job using Redis.
Uses the SET command with NX and EX options to ensure atomicity.
@@ -44,12 +44,13 @@ class PodLockManager:
verbose_proxy_logger.debug(
"Pod %s attempting to acquire Redis lock for cronjob_id=%s",
self.pod_id,
- self.cronjob_id,
+ cronjob_id,
)
# Try to set the lock key with the pod_id as its value, only if it doesn't exist (NX)
# and with an expiration (EX) to avoid deadlocks.
+ lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
acquired = await self.redis_cache.async_set_cache(
- self.lock_key,
+ lock_key,
self.pod_id,
nx=True,
ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
@@ -58,13 +59,13 @@ class PodLockManager:
verbose_proxy_logger.info(
"Pod %s successfully acquired Redis lock for cronjob_id=%s",
self.pod_id,
- self.cronjob_id,
+ cronjob_id,
)
return True
else:
# Check if the current pod already holds the lock
- current_value = await self.redis_cache.async_get_cache(self.lock_key)
+ current_value = await self.redis_cache.async_get_cache(lock_key)
if current_value is not None:
if isinstance(current_value, bytes):
current_value = current_value.decode("utf-8")
@@ -72,18 +73,21 @@ class PodLockManager:
verbose_proxy_logger.info(
"Pod %s already holds the Redis lock for cronjob_id=%s",
self.pod_id,
- self.cronjob_id,
+ cronjob_id,
)
- self._emit_acquired_lock_event(self.cronjob_id, self.pod_id)
+ self._emit_acquired_lock_event(cronjob_id, self.pod_id)
return True
return False
except Exception as e:
verbose_proxy_logger.error(
- f"Error acquiring Redis lock for {self.cronjob_id}: {e}"
+ f"Error acquiring Redis lock for {cronjob_id}: {e}"
)
return False
- async def release_lock(self):
+ async def release_lock(
+ self,
+ cronjob_id: str,
+ ):
"""
Release the lock if the current pod holds it.
Uses get and delete commands to ensure that only the owner can release the lock.
@@ -92,46 +96,52 @@ class PodLockManager:
verbose_proxy_logger.debug("redis_cache is None, skipping release_lock")
return
try:
verbose_proxy_logger.debug(
"Pod %s attempting to release Redis lock for cronjob_id=%s",
self.pod_id,
- self.cronjob_id,
+ cronjob_id,
)
- current_value = await self.redis_cache.async_get_cache(self.lock_key)
+ lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+
+ current_value = await self.redis_cache.async_get_cache(lock_key)
if current_value is not None:
if isinstance(current_value, bytes):
current_value = current_value.decode("utf-8")
if current_value == self.pod_id:
- result = await self.redis_cache.async_delete_cache(self.lock_key)
+ result = await self.redis_cache.async_delete_cache(lock_key)
if result == 1:
verbose_proxy_logger.info(
"Pod %s successfully released Redis lock for cronjob_id=%s",
self.pod_id,
- self.cronjob_id,
+ cronjob_id,
+ )
+ self._emit_released_lock_event(
+ cronjob_id=cronjob_id,
+ pod_id=self.pod_id,
)
- self._emit_released_lock_event(self.cronjob_id, self.pod_id)
else:
verbose_proxy_logger.debug(
"Pod %s failed to release Redis lock for cronjob_id=%s",
self.pod_id,
- self.cronjob_id,
+ cronjob_id,
)
else:
verbose_proxy_logger.debug(
"Pod %s cannot release Redis lock for cronjob_id=%s because it is held by pod %s",
self.pod_id,
- self.cronjob_id,
+ cronjob_id,
current_value,
)
else:
verbose_proxy_logger.debug(
"Pod %s attempted to release Redis lock for cronjob_id=%s, but no lock was found",
self.pod_id,
- self.cronjob_id,
+ cronjob_id,
)
except Exception as e:
verbose_proxy_logger.error(
- f"Error releasing Redis lock for {self.cronjob_id}: {e}"
+ f"Error releasing Redis lock for {cronjob_id}: {e}"
)
@staticmethod
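
A sketch of the refactored PodLockManager API: a single instance can now guard multiple cron jobs, with the `cronjob_id` passed per call instead of at `__init__`. The job-name constants are the ones defined in `litellm/constants.py` in this PR; the `redis_cache` argument is assumed to be a configured RedisCache instance:

```python
from litellm.constants import (
    DB_SPEND_UPDATE_JOB_NAME,
    PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME,
)
from litellm.proxy.db.db_transaction_queue.pod_lock_manager import PodLockManager


async def run_leader_only_jobs(redis_cache):  # redis_cache: a configured RedisCache
    lock_manager = PodLockManager(redis_cache=redis_cache)

    for job_id in (DB_SPEND_UPDATE_JOB_NAME, PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME):
        # only the pod that wins the Redis SET NX call becomes the leader for this job
        if await lock_manager.acquire_lock(cronjob_id=job_id):
            try:
                ...  # leader-only work for this cron job goes here
            finally:
                await lock_manager.release_lock(cronjob_id=job_id)
```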
diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py
index 9ead84ab7e..0365336e73 100644
--- a/litellm/proxy/management_endpoints/ui_sso.py
+++ b/litellm/proxy/management_endpoints/ui_sso.py
@@ -876,6 +876,68 @@ class SSOAuthenticationHandler:
sso_teams = getattr(result, "team_ids", [])
await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
+ @staticmethod
+ async def create_litellm_team_from_sso_group(
+ litellm_team_id: str,
+ litellm_team_name: Optional[str] = None,
+ ):
+ """
+        Creates a Litellm Team from an SSO Group ID
+
+        Your SSO provider might have groups that should exist as teams on LiteLLM
+
+        Use this helper to create a Litellm Team from an SSO Group ID
+
+ Args:
+ litellm_team_id (str): The ID of the Litellm Team
+ litellm_team_name (Optional[str]): The name of the Litellm Team
+ """
+ from litellm.proxy.proxy_server import prisma_client
+
+ if prisma_client is None:
+ raise ProxyException(
+ message="Prisma client not found. Set it in the proxy_server.py file",
+ type=ProxyErrorTypes.auth_error,
+ param="prisma_client",
+ code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ )
+ try:
+ team_obj = await prisma_client.db.litellm_teamtable.find_first(
+ where={"team_id": litellm_team_id}
+ )
+ verbose_proxy_logger.debug(f"Team object: {team_obj}")
+
+ # only create a new team if it doesn't exist
+ if team_obj:
+ verbose_proxy_logger.debug(
+ f"Team already exists: {litellm_team_id} - {litellm_team_name}"
+ )
+ return
+
+ team_request: NewTeamRequest = NewTeamRequest(
+ team_id=litellm_team_id,
+ team_alias=litellm_team_name,
+ )
+ if litellm.default_team_params:
+ team_request = litellm.default_team_params.model_copy(
+ deep=True,
+ update={
+ "team_id": litellm_team_id,
+ "team_alias": litellm_team_name,
+ },
+ )
+ await new_team(
+ data=team_request,
+ # params used for Audit Logging
+ http_request=Request(scope={"type": "http", "method": "POST"}),
+ user_api_key_dict=UserAPIKeyAuth(
+ token="",
+ key_alias=f"litellm.{MicrosoftSSOHandler.__name__}",
+ ),
+ )
+ except Exception as e:
+ verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}")
+
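
A standalone illustration of the `model_copy(deep=True, update=...)` pattern used above: the configured defaults are kept while the per-group fields are overridden. `DemoTeamRequest` is a hypothetical stand-in for `NewTeamRequest` so the snippet runs without litellm:

```python
from typing import List, Optional

from pydantic import BaseModel


class DemoTeamRequest(BaseModel):  # stand-in, not litellm's NewTeamRequest
    team_id: Optional[str] = None
    team_alias: Optional[str] = None
    max_budget: Optional[float] = None
    models: List[str] = []


default_team_params = DemoTeamRequest(max_budget=100, models=["gpt-3.5-turbo"])

# deep copy keeps the configured defaults; update overrides the per-group fields
team_request = default_team_params.model_copy(
    deep=True,
    update={"team_id": "sso-group-123", "team_alias": "Engineering"},
)
assert team_request.max_budget == 100
assert team_request.team_id == "sso-group-123"
```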
class MicrosoftSSOHandler:
"""
@@ -1156,15 +1218,6 @@ class MicrosoftSSOHandler:
When a user sets a `SERVICE_PRINCIPAL_ID` in the env, litellm will fetch groups under that service principal and create Litellm Teams from them
"""
- from litellm.proxy.proxy_server import prisma_client
-
- if prisma_client is None:
- raise ProxyException(
- message="Prisma client not found. Set it in the proxy_server.py file",
- type=ProxyErrorTypes.auth_error,
- param="prisma_client",
- code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- )
verbose_proxy_logger.debug(
f"Creating Litellm Teams from Service Principal Teams: {service_principal_teams}"
)
@@ -1179,36 +1232,10 @@ class MicrosoftSSOHandler:
)
continue
- try:
- verbose_proxy_logger.debug(
- f"Creating Litellm Team: {litellm_team_id} - {litellm_team_name}"
- )
-
- team_obj = await prisma_client.db.litellm_teamtable.find_first(
- where={"team_id": litellm_team_id}
- )
- verbose_proxy_logger.debug(f"Team object: {team_obj}")
-
- # only create a new team if it doesn't exist
- if team_obj:
- verbose_proxy_logger.debug(
- f"Team already exists: {litellm_team_id} - {litellm_team_name}"
- )
- continue
- await new_team(
- data=NewTeamRequest(
- team_id=litellm_team_id,
- team_alias=litellm_team_name,
- ),
- # params used for Audit Logging
- http_request=Request(scope={"type": "http", "method": "POST"}),
- user_api_key_dict=UserAPIKeyAuth(
- token="",
- key_alias=f"litellm.{MicrosoftSSOHandler.__name__}",
- ),
- )
- except Exception as e:
- verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}")
+ await SSOAuthenticationHandler.create_litellm_team_from_sso_group(
+ litellm_team_id=litellm_team_id,
+ litellm_team_name=litellm_team_name,
+ )
class GoogleSSOHandler:
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 23de923db7..847ca7ce56 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -10,6 +10,36 @@ model_list:
api_key: fake-key
litellm_settings:
+ prometheus_initialize_budget_metrics: true
+ callbacks: ["prometheus"]
+
+mcp_tools:
+ - name: "get_current_time"
+ description: "Get the current time"
+ input_schema: {
+ "type": "object",
+ "properties": {
+ "format": {
+ "type": "string",
+ "description": "The format of the time to return",
+ "enum": ["short"]
+ }
+ }
+ }
+ handler: "mcp_tools.get_current_time"
+ - name: "get_current_date"
+ description: "Get the current date"
+ input_schema: {
+ "type": "object",
+ "properties": {
+ "format": {
+ "type": "string",
+ "description": "The format of the date to return",
+ "enum": ["short"]
+ }
+ }
+ }
+ handler: "mcp_tools.get_current_date"
default_team_settings:
- team_id: test_dev
success_callback: ["langfuse", "s3"]
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index ddfb7118d7..84b515f405 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -803,9 +803,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
dual_cache=user_api_key_cache
)
litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
-redis_usage_cache: Optional[
- RedisCache
-] = None # redis cache used for tracking spend, tpm/rpm limits
+redis_usage_cache: Optional[RedisCache] = (
+ None # redis cache used for tracking spend, tpm/rpm limits
+)
user_custom_auth = None
user_custom_key_generate = None
user_custom_sso = None
@@ -1131,9 +1131,9 @@ async def update_cache( # noqa: PLR0915
_id = "team_id:{}".format(team_id)
try:
# Fetch the existing cost for the given user
- existing_spend_obj: Optional[
- LiteLLM_TeamTable
- ] = await user_api_key_cache.async_get_cache(key=_id)
+ existing_spend_obj: Optional[LiteLLM_TeamTable] = (
+ await user_api_key_cache.async_get_cache(key=_id)
+ )
if existing_spend_obj is None:
# do nothing if team not in api key cache
return
@@ -2812,9 +2812,9 @@ async def initialize( # noqa: PLR0915
user_api_base = api_base
dynamic_config[user_model]["api_base"] = api_base
if api_version:
- os.environ[
- "AZURE_API_VERSION"
- ] = api_version # set this for azure - litellm can read this from the env
+ os.environ["AZURE_API_VERSION"] = (
+ api_version # set this for azure - litellm can read this from the env
+ )
if max_tokens: # model-specific param
dynamic_config[user_model]["max_tokens"] = max_tokens
if temperature: # model-specific param
@@ -3191,6 +3191,11 @@ class ProxyStartupEvent:
)
await proxy_logging_obj.slack_alerting_instance.send_fallback_stats_from_prometheus()
+ if litellm.prometheus_initialize_budget_metrics is True:
+ from litellm.integrations.prometheus import PrometheusLogger
+
+ PrometheusLogger.initialize_budget_metrics_cron_job(scheduler=scheduler)
+
scheduler.start()
@classmethod
@@ -7753,9 +7758,9 @@ async def get_config_list(
hasattr(sub_field_info, "description")
and sub_field_info.description is not None
):
- nested_fields[
- idx
- ].field_description = sub_field_info.description
+ nested_fields[idx].field_description = (
+ sub_field_info.description
+ )
idx += 1
_stored_in_db = None
diff --git a/tests/litellm/integrations/test_prometheus.py b/tests/litellm/integrations/test_prometheus.py
new file mode 100644
index 0000000000..464477f019
--- /dev/null
+++ b/tests/litellm/integrations/test_prometheus.py
@@ -0,0 +1,44 @@
+"""
+Mock prometheus unit tests, these don't rely on LLM API calls
+"""
+
+import json
+import os
+import sys
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+ 0, os.path.abspath("../../..")
+) # Adds the parent directory to the system path
+
+from apscheduler.schedulers.asyncio import AsyncIOScheduler
+
+import litellm
+from litellm.constants import PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
+from litellm.integrations.prometheus import PrometheusLogger
+
+
+def test_initialize_budget_metrics_cron_job():
+ # Create a scheduler
+ scheduler = AsyncIOScheduler()
+
+ # Create and register a PrometheusLogger
+ prometheus_logger = PrometheusLogger()
+ litellm.callbacks = [prometheus_logger]
+
+ # Initialize the cron job
+ PrometheusLogger.initialize_budget_metrics_cron_job(scheduler)
+
+ # Verify that a job was added to the scheduler
+ jobs = scheduler.get_jobs()
+ assert len(jobs) == 1
+
+ # Verify job properties
+ job = jobs[0]
+ assert (
+ job.trigger.interval.total_seconds() / 60
+ == PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
+ )
+ assert job.func.__name__ == "initialize_remaining_budget_metrics"
diff --git a/tests/litellm/proxy/db/db_transaction_queue/test_pod_lock_manager.py b/tests/litellm/proxy/db/db_transaction_queue/test_pod_lock_manager.py
index 697d985dc9..e83fd75c3a 100644
--- a/tests/litellm/proxy/db/db_transaction_queue/test_pod_lock_manager.py
+++ b/tests/litellm/proxy/db/db_transaction_queue/test_pod_lock_manager.py
@@ -29,7 +29,7 @@ def mock_redis():
@pytest.fixture
def pod_lock_manager(mock_redis):
- return PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
+ return PodLockManager(redis_cache=mock_redis)
@pytest.mark.asyncio
@@ -40,12 +40,15 @@ async def test_acquire_lock_success(pod_lock_manager, mock_redis):
# Mock successful acquisition (SET NX returns True)
mock_redis.async_set_cache.return_value = True
- result = await pod_lock_manager.acquire_lock()
+ result = await pod_lock_manager.acquire_lock(
+ cronjob_id="test_job",
+ )
assert result == True
# Verify set_cache was called with correct parameters
+ lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job")
mock_redis.async_set_cache.assert_called_once_with(
- pod_lock_manager.lock_key,
+ lock_key,
pod_lock_manager.pod_id,
nx=True,
ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
@@ -62,13 +65,16 @@ async def test_acquire_lock_existing_active(pod_lock_manager, mock_redis):
# Mock get_cache to return a different pod's ID
mock_redis.async_get_cache.return_value = "different_pod_id"
- result = await pod_lock_manager.acquire_lock()
+ result = await pod_lock_manager.acquire_lock(
+ cronjob_id="test_job",
+ )
assert result == False
# Verify set_cache was called
mock_redis.async_set_cache.assert_called_once()
# Verify get_cache was called to check existing lock
- mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key)
+ lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job")
+ mock_redis.async_get_cache.assert_called_once_with(lock_key)
@pytest.mark.asyncio
@@ -89,7 +95,9 @@ async def test_acquire_lock_expired(pod_lock_manager, mock_redis):
# Then set succeeds on retry (simulating key expiring between checks)
mock_redis.async_set_cache.side_effect = [False, True]
- result = await pod_lock_manager.acquire_lock()
+ result = await pod_lock_manager.acquire_lock(
+ cronjob_id="test_job",
+ )
assert result == False # First attempt fails
# Reset mock for a second attempt
@@ -97,7 +105,9 @@ async def test_acquire_lock_expired(pod_lock_manager, mock_redis):
mock_redis.async_set_cache.return_value = True
# Try again (simulating the lock expired)
- result = await pod_lock_manager.acquire_lock()
+ result = await pod_lock_manager.acquire_lock(
+ cronjob_id="test_job",
+ )
assert result == True
# Verify set_cache was called again
@@ -114,12 +124,15 @@ async def test_release_lock_success(pod_lock_manager, mock_redis):
# Mock successful deletion
mock_redis.async_delete_cache.return_value = 1
- await pod_lock_manager.release_lock()
+ await pod_lock_manager.release_lock(
+ cronjob_id="test_job",
+ )
# Verify get_cache was called
- mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key)
+ lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job")
+ mock_redis.async_get_cache.assert_called_once_with(lock_key)
# Verify delete_cache was called
- mock_redis.async_delete_cache.assert_called_once_with(pod_lock_manager.lock_key)
+ mock_redis.async_delete_cache.assert_called_once_with(lock_key)
@pytest.mark.asyncio
@@ -130,10 +143,13 @@ async def test_release_lock_different_pod(pod_lock_manager, mock_redis):
# Mock get_cache to return a different pod's ID
mock_redis.async_get_cache.return_value = "different_pod_id"
- await pod_lock_manager.release_lock()
+ await pod_lock_manager.release_lock(
+ cronjob_id="test_job",
+ )
# Verify get_cache was called
- mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key)
+ lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job")
+ mock_redis.async_get_cache.assert_called_once_with(lock_key)
# Verify delete_cache was NOT called
mock_redis.async_delete_cache.assert_not_called()
@@ -146,10 +162,13 @@ async def test_release_lock_no_lock(pod_lock_manager, mock_redis):
# Mock get_cache to return None (no lock)
mock_redis.async_get_cache.return_value = None
- await pod_lock_manager.release_lock()
+ await pod_lock_manager.release_lock(
+ cronjob_id="test_job",
+ )
# Verify get_cache was called
- mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key)
+ lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job")
+ mock_redis.async_get_cache.assert_called_once_with(lock_key)
# Verify delete_cache was NOT called
mock_redis.async_delete_cache.assert_not_called()
@@ -159,13 +178,20 @@ async def test_redis_none(monkeypatch):
"""
Test behavior when redis_cache is None
"""
- pod_lock_manager = PodLockManager(cronjob_id="test_job", redis_cache=None)
+ pod_lock_manager = PodLockManager(redis_cache=None)
# Test acquire_lock with None redis_cache
- assert await pod_lock_manager.acquire_lock() is None
+ assert (
+ await pod_lock_manager.acquire_lock(
+ cronjob_id="test_job",
+ )
+ is None
+ )
# Test release_lock with None redis_cache (should not raise exception)
- await pod_lock_manager.release_lock()
+ await pod_lock_manager.release_lock(
+ cronjob_id="test_job",
+ )
@pytest.mark.asyncio
@@ -179,7 +205,9 @@ async def test_redis_error_handling(pod_lock_manager, mock_redis):
mock_redis.async_delete_cache.side_effect = Exception("Redis error")
# Test acquire_lock error handling
- result = await pod_lock_manager.acquire_lock()
+ result = await pod_lock_manager.acquire_lock(
+ cronjob_id="test_job",
+ )
assert result == False
# Reset side effect for get_cache for the release test
@@ -187,7 +215,9 @@ async def test_redis_error_handling(pod_lock_manager, mock_redis):
mock_redis.async_get_cache.return_value = pod_lock_manager.pod_id
# Test release_lock error handling (should not raise exception)
- await pod_lock_manager.release_lock()
+ await pod_lock_manager.release_lock(
+ cronjob_id="test_job",
+ )
@pytest.mark.asyncio
@@ -200,14 +230,18 @@ async def test_bytes_handling(pod_lock_manager, mock_redis):
# Mock get_cache to return bytes
mock_redis.async_get_cache.return_value = pod_lock_manager.pod_id.encode("utf-8")
- result = await pod_lock_manager.acquire_lock()
+ result = await pod_lock_manager.acquire_lock(
+ cronjob_id="test_job",
+ )
assert result == True
# Reset for release test
mock_redis.async_get_cache.return_value = pod_lock_manager.pod_id.encode("utf-8")
mock_redis.async_delete_cache.return_value = 1
- await pod_lock_manager.release_lock()
+ await pod_lock_manager.release_lock(
+ cronjob_id="test_job",
+ )
mock_redis.async_delete_cache.assert_called_once()
@@ -217,15 +251,17 @@ async def test_concurrent_lock_acquisition_simulation():
Simulate multiple pods trying to acquire the lock simultaneously
"""
mock_redis = MockRedisCache()
- pod1 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
- pod2 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
- pod3 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
+ pod1 = PodLockManager(redis_cache=mock_redis)
+ pod2 = PodLockManager(redis_cache=mock_redis)
+ pod3 = PodLockManager(redis_cache=mock_redis)
# Simulate first pod getting the lock
mock_redis.async_set_cache.return_value = True
# First pod should get the lock
- result1 = await pod1.acquire_lock()
+ result1 = await pod1.acquire_lock(
+ cronjob_id="test_job",
+ )
assert result1 == True
# Simulate other pods failing to get the lock
@@ -233,8 +269,12 @@ async def test_concurrent_lock_acquisition_simulation():
mock_redis.async_get_cache.return_value = pod1.pod_id
# Other pods should fail to acquire
- result2 = await pod2.acquire_lock()
- result3 = await pod3.acquire_lock()
+ result2 = await pod2.acquire_lock(
+ cronjob_id="test_job",
+ )
+ result3 = await pod3.acquire_lock(
+ cronjob_id="test_job",
+ )
# Since other pods don't have the lock, they should get False
assert result2 == False
@@ -246,14 +286,16 @@ async def test_lock_takeover_race_condition(mock_redis):
"""
Test scenario where multiple pods try to take over an expired lock using Redis
"""
- pod1 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
- pod2 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
+ pod1 = PodLockManager(redis_cache=mock_redis)
+ pod2 = PodLockManager(redis_cache=mock_redis)
# Simulate first pod's acquisition succeeding
mock_redis.async_set_cache.return_value = True
# First pod should successfully acquire
- result1 = await pod1.acquire_lock()
+ result1 = await pod1.acquire_lock(
+ cronjob_id="test_job",
+ )
assert result1 == True
# Simulate race condition: second pod tries but fails
@@ -261,5 +303,7 @@ async def test_lock_takeover_race_condition(mock_redis):
mock_redis.async_get_cache.return_value = pod1.pod_id
# Second pod should fail to acquire
- result2 = await pod2.acquire_lock()
+ result2 = await pod2.acquire_lock(
+ cronjob_id="test_job",
+ )
assert result2 == False
diff --git a/tests/litellm/proxy/management_endpoints/test_ui_sso.py b/tests/litellm/proxy/management_endpoints/test_ui_sso.py
index 5b1a9582c7..09e337bf84 100644
--- a/tests/litellm/proxy/management_endpoints/test_ui_sso.py
+++ b/tests/litellm/proxy/management_endpoints/test_ui_sso.py
@@ -2,6 +2,7 @@ import asyncio
import json
import os
import sys
+import uuid
from typing import Optional, cast
from unittest.mock import AsyncMock, MagicMock, patch
@@ -13,6 +14,8 @@ sys.path.insert(
0, os.path.abspath("../../../")
) # Adds the parent directory to the system path
+import litellm
+from litellm.proxy._types import NewTeamRequest
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.management_endpoints.types import CustomOpenID
from litellm.proxy.management_endpoints.ui_sso import (
@@ -23,6 +26,7 @@ from litellm.proxy.management_endpoints.ui_sso import (
from litellm.types.proxy.management_endpoints.ui_sso import (
MicrosoftGraphAPIUserGroupDirectoryObject,
MicrosoftGraphAPIUserGroupResponse,
+ MicrosoftServicePrincipalTeam,
)
@@ -460,3 +464,95 @@ async def test_upsert_sso_user_existing_user():
data={"user_email": "new_email@example.com"},
)
assert result == mock_user
+
+
+@pytest.mark.asyncio
+async def test_default_team_params():
+ """
+ When litellm.default_team_params is set, it should be used to create a new team
+ """
+ # Arrange
+ litellm.default_team_params = NewTeamRequest(
+ max_budget=10, budget_duration="1d", models=["special-gpt-5"]
+ )
+
+ def mock_jsonify_team_object(db_data):
+ return db_data
+
+ # Mock Prisma client
+ mock_prisma = MagicMock()
+ mock_prisma.db.litellm_teamtable.find_first = AsyncMock(return_value=None)
+ mock_prisma.db.litellm_teamtable.create = AsyncMock()
+ mock_prisma.get_data = AsyncMock(return_value=None)
+ mock_prisma.jsonify_team_object = MagicMock(side_effect=mock_jsonify_team_object)
+
+ with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
+ # Act
+ team_id = str(uuid.uuid4())
+ await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids(
+ service_principal_teams=[
+ MicrosoftServicePrincipalTeam(
+ principalId=team_id,
+ principalDisplayName="Test Team",
+ )
+ ]
+ )
+
+ # Assert
+ # Verify team was created with correct parameters
+ mock_prisma.db.litellm_teamtable.create.assert_called_once()
+ print(
+ "mock_prisma.db.litellm_teamtable.create.call_args",
+ mock_prisma.db.litellm_teamtable.create.call_args,
+ )
+ create_call_args = mock_prisma.db.litellm_teamtable.create.call_args.kwargs[
+ "data"
+ ]
+ assert create_call_args["team_id"] == team_id
+ assert create_call_args["team_alias"] == "Test Team"
+ assert create_call_args["max_budget"] == 10
+ assert create_call_args["budget_duration"] == "1d"
+ assert create_call_args["models"] == ["special-gpt-5"]
+
+
+@pytest.mark.asyncio
+async def test_create_team_without_default_params():
+ """
+ Test team creation when litellm.default_team_params is None
+ Should create team with just the basic required fields
+ """
+ # Arrange
+ litellm.default_team_params = None
+
+ def mock_jsonify_team_object(db_data):
+ return db_data
+
+ # Mock Prisma client
+ mock_prisma = MagicMock()
+ mock_prisma.db.litellm_teamtable.find_first = AsyncMock(return_value=None)
+ mock_prisma.db.litellm_teamtable.create = AsyncMock()
+ mock_prisma.get_data = AsyncMock(return_value=None)
+ mock_prisma.jsonify_team_object = MagicMock(side_effect=mock_jsonify_team_object)
+
+ with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma):
+ # Act
+ team_id = str(uuid.uuid4())
+ await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids(
+ service_principal_teams=[
+ MicrosoftServicePrincipalTeam(
+ principalId=team_id,
+ principalDisplayName="Test Team",
+ )
+ ]
+ )
+
+ # Assert
+ mock_prisma.db.litellm_teamtable.create.assert_called_once()
+ create_call_args = mock_prisma.db.litellm_teamtable.create.call_args.kwargs[
+ "data"
+ ]
+ assert create_call_args["team_id"] == team_id
+ assert create_call_args["team_alias"] == "Test Team"
+ # Should not have any of the optional fields
+ assert "max_budget" not in create_call_args
+ assert "budget_duration" not in create_call_args
+ assert create_call_args["models"] == []
diff --git a/tests/logging_callback_tests/test_prometheus_unit_tests.py b/tests/logging_callback_tests/test_prometheus_unit_tests.py
index ddfce710d7..0b58bc7aaf 100644
--- a/tests/logging_callback_tests/test_prometheus_unit_tests.py
+++ b/tests/logging_callback_tests/test_prometheus_unit_tests.py
@@ -39,7 +39,7 @@ import time
@pytest.fixture
-def prometheus_logger():
+def prometheus_logger() -> PrometheusLogger:
collectors = list(REGISTRY._collector_to_names.keys())
for collector in collectors:
REGISTRY.unregister(collector)
@@ -1212,24 +1212,6 @@ async def test_initialize_remaining_budget_metrics_exception_handling(
prometheus_logger.litellm_remaining_api_key_budget_metric.assert_not_called()
-def test_initialize_prometheus_startup_metrics_no_loop(prometheus_logger):
- """
- Test that _initialize_prometheus_startup_metrics handles case when no event loop exists
- """
- # Mock asyncio.get_running_loop to raise RuntimeError
- litellm.prometheus_initialize_budget_metrics = True
- with patch(
- "asyncio.get_running_loop", side_effect=RuntimeError("No running event loop")
- ), patch("litellm._logging.verbose_logger.exception") as mock_logger:
-
- # Call the function
- prometheus_logger._initialize_prometheus_startup_metrics()
-
- # Verify the error was logged
- mock_logger.assert_called_once()
- assert "No running event loop" in mock_logger.call_args[0][0]
-
-
@pytest.mark.asyncio(scope="session")
async def test_initialize_api_key_budget_metrics(prometheus_logger):
"""
diff --git a/tests/proxy_unit_tests/test_e2e_pod_lock_manager.py b/tests/proxy_unit_tests/test_e2e_pod_lock_manager.py
index 652b1838ac..061da8c186 100644
--- a/tests/proxy_unit_tests/test_e2e_pod_lock_manager.py
+++ b/tests/proxy_unit_tests/test_e2e_pod_lock_manager.py
@@ -141,10 +141,12 @@ async def setup_db_connection(prisma_client):
async def test_pod_lock_acquisition_when_no_active_lock():
"""Test if a pod can acquire a lock when no lock is active"""
cronjob_id = str(uuid.uuid4())
- lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+ lock_manager = PodLockManager(redis_cache=global_redis_cache)
# Attempt to acquire lock
- result = await lock_manager.acquire_lock()
+ result = await lock_manager.acquire_lock(
+ cronjob_id=cronjob_id,
+ )
assert result == True, "Pod should be able to acquire lock when no lock exists"
@@ -161,13 +163,19 @@ async def test_pod_lock_acquisition_after_completion():
"""Test if a new pod can acquire lock after previous pod completes"""
cronjob_id = str(uuid.uuid4())
# First pod acquires and releases lock
- first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
- await first_lock_manager.acquire_lock()
- await first_lock_manager.release_lock()
+ first_lock_manager = PodLockManager(redis_cache=global_redis_cache)
+ await first_lock_manager.acquire_lock(
+ cronjob_id=cronjob_id,
+ )
+ await first_lock_manager.release_lock(
+ cronjob_id=cronjob_id,
+ )
# Second pod attempts to acquire lock
- second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
- result = await second_lock_manager.acquire_lock()
+ second_lock_manager = PodLockManager(redis_cache=global_redis_cache)
+ result = await second_lock_manager.acquire_lock(
+ cronjob_id=cronjob_id,
+ )
assert result == True, "Second pod should acquire lock after first pod releases it"
@@ -182,15 +190,21 @@ async def test_pod_lock_acquisition_after_expiry():
"""Test if a new pod can acquire lock after previous pod's lock expires"""
cronjob_id = str(uuid.uuid4())
# First pod acquires lock
- first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
- await first_lock_manager.acquire_lock()
+ first_lock_manager = PodLockManager(redis_cache=global_redis_cache)
+ await first_lock_manager.acquire_lock(
+ cronjob_id=cronjob_id,
+ )
# release the lock from the first pod
- await first_lock_manager.release_lock()
+ await first_lock_manager.release_lock(
+ cronjob_id=cronjob_id,
+ )
# Second pod attempts to acquire lock
- second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
- result = await second_lock_manager.acquire_lock()
+ second_lock_manager = PodLockManager(redis_cache=global_redis_cache)
+ result = await second_lock_manager.acquire_lock(
+ cronjob_id=cronjob_id,
+ )
assert (
result == True
@@ -206,11 +220,15 @@ async def test_pod_lock_acquisition_after_expiry():
async def test_pod_lock_release():
"""Test if a pod can successfully release its lock"""
cronjob_id = str(uuid.uuid4())
- lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+ lock_manager = PodLockManager(redis_cache=global_redis_cache)
# Acquire and then release lock
- await lock_manager.acquire_lock()
- await lock_manager.release_lock()
+ await lock_manager.acquire_lock(
+ cronjob_id=cronjob_id,
+ )
+ await lock_manager.release_lock(
+ cronjob_id=cronjob_id,
+ )
# Verify in redis
lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
@@ -224,15 +242,21 @@ async def test_concurrent_lock_acquisition():
cronjob_id = str(uuid.uuid4())
# Create multiple lock managers simulating different pods
- lock_manager1 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
- lock_manager2 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
- lock_manager3 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+ lock_manager1 = PodLockManager(redis_cache=global_redis_cache)
+ lock_manager2 = PodLockManager(redis_cache=global_redis_cache)
+ lock_manager3 = PodLockManager(redis_cache=global_redis_cache)
# Try to acquire locks concurrently
results = await asyncio.gather(
- lock_manager1.acquire_lock(),
- lock_manager2.acquire_lock(),
- lock_manager3.acquire_lock(),
+ lock_manager1.acquire_lock(
+ cronjob_id=cronjob_id,
+ ),
+ lock_manager2.acquire_lock(
+ cronjob_id=cronjob_id,
+ ),
+ lock_manager3.acquire_lock(
+ cronjob_id=cronjob_id,
+ ),
)
# Only one should succeed
@@ -254,7 +278,7 @@ async def test_concurrent_lock_acquisition():
async def test_lock_acquisition_with_expired_ttl():
"""Test that a pod can acquire a lock when existing lock has expired TTL"""
cronjob_id = str(uuid.uuid4())
- first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+ first_lock_manager = PodLockManager(redis_cache=global_redis_cache)
# First pod acquires lock with a very short TTL to simulate expiration
short_ttl = 1 # 1 second
@@ -269,8 +293,10 @@ async def test_lock_acquisition_with_expired_ttl():
await asyncio.sleep(short_ttl + 0.5) # Wait slightly longer than the TTL
# Second pod tries to acquire without explicit release
- second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
- result = await second_lock_manager.acquire_lock()
+ second_lock_manager = PodLockManager(redis_cache=global_redis_cache)
+ result = await second_lock_manager.acquire_lock(
+ cronjob_id=cronjob_id,
+ )
assert result == True, "Should acquire lock when existing lock has expired TTL"
@@ -286,7 +312,7 @@ async def test_release_expired_lock():
cronjob_id = str(uuid.uuid4())
# First pod acquires lock with a very short TTL
- first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+ first_lock_manager = PodLockManager(redis_cache=global_redis_cache)
short_ttl = 1 # 1 second
lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
await global_redis_cache.async_set_cache(
@@ -299,11 +325,15 @@ async def test_release_expired_lock():
await asyncio.sleep(short_ttl + 0.5) # Wait slightly longer than the TTL
# Second pod acquires the lock
- second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
- await second_lock_manager.acquire_lock()
+ second_lock_manager = PodLockManager(redis_cache=global_redis_cache)
+ await second_lock_manager.acquire_lock(
+ cronjob_id=cronjob_id,
+ )
# First pod attempts to release its lock
- await first_lock_manager.release_lock()
+ await first_lock_manager.release_lock(
+ cronjob_id=cronjob_id,
+ )
# Verify that second pod's lock is still active
lock_record = await global_redis_cache.async_get_cache(lock_key)