From 669b4fc95579d86ea3d6393ea51a4e3f9352f8c1 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Sat, 25 Jan 2025 10:37:52 -0800
Subject: [PATCH] (Prometheus) - emit key budget metrics on startup (#8002)

* add UI_SESSION_TOKEN_TEAM_ID
* add type KeyListResponseObject
* add _list_key_helper
* _initialize_api_key_budget_metrics
* key / budget metrics
* init key budget metrics on startup
* test_initialize_api_key_budget_metrics
* fix linting
* test_list_key_helper
* test_initialize_remaining_budget_metrics_exception_handling
---
 docs/my-website/docs/proxy/prometheus.md |  62 ++++---
 litellm/__init__.py                      |   1 +
 litellm/constants.py                     |   2 +
 litellm/integrations/prometheus.py       | 140 +++++++++++++--
 litellm/proxy/_types.py                  |   7 +
 .../key_management_endpoints.py          | 164 ++++++++++++------
 litellm/proxy/proxy_config.yaml          |   4 +-
 .../test_prometheus_unit_tests.py        | 131 +++++++++++++-
 .../test_key_management.py               | 146 ++++++++++++++++
 9 files changed, 555 insertions(+), 102 deletions(-)

diff --git a/docs/my-website/docs/proxy/prometheus.md b/docs/my-website/docs/proxy/prometheus.md
index a0e19a006d..8dff527ae5 100644
--- a/docs/my-website/docs/proxy/prometheus.md
+++ b/docs/my-website/docs/proxy/prometheus.md
@@ -57,7 +57,7 @@ http://localhost:4000/metrics # /metrics
 ```
 
-## Virtual Keys, Teams, Internal Users Metrics
+## Virtual Keys, Teams, Internal Users
 
 Use this for tracking per [user, key, team, etc.](virtual_keys)
 
@@ -68,6 +68,42 @@ Use this for tracking per [user, key, team, etc.](virtual_keys)
 | `litellm_input_tokens` | input tokens per `"end_user", "hashed_api_key", "api_key_alias", "requested_model", "team", "team_alias", "user", "model"` |
 | `litellm_output_tokens` | output tokens per `"end_user", "hashed_api_key", "api_key_alias", "requested_model", "team", "team_alias", "user", "model"` |
 
+### Team - Budget
+
+| Metric Name | Description |
+|----------------------|--------------------------------------|
+| `litellm_team_max_budget_metric` | Max budget for a team. Labels: `"team_id", "team_alias"` |
+| `litellm_remaining_team_budget_metric` | Remaining budget for a team (a team created on LiteLLM). Labels: `"team_id", "team_alias"` |
+| `litellm_team_budget_remaining_hours_metric` | Hours before the team budget is reset. Labels: `"team_id", "team_alias"` |
+
+### Virtual Key - Budget
+
+| Metric Name | Description |
+|----------------------|--------------------------------------|
+| `litellm_api_key_max_budget_metric` | Max budget for an API key. Labels: `"hashed_api_key", "api_key_alias"` |
+| `litellm_remaining_api_key_budget_metric` | Remaining budget for an API key (a key created on LiteLLM). Labels: `"hashed_api_key", "api_key_alias"` |
+| `litellm_api_key_budget_remaining_hours_metric` | Hours before the API key budget is reset. Labels: `"hashed_api_key", "api_key_alias"` |
+
+### Virtual Key - Rate Limit
+
+| Metric Name | Description |
+|----------------------|--------------------------------------|
+| `litellm_remaining_api_key_requests_for_model` | Remaining requests for a LiteLLM virtual API key, only if a model-specific rate limit (rpm) has been set for that virtual key. Labels: `"hashed_api_key", "api_key_alias", "model"` |
+| `litellm_remaining_api_key_tokens_for_model` | Remaining tokens for a LiteLLM virtual API key, only if a model-specific token limit (tpm) has been set for that virtual key. Labels: `"hashed_api_key", "api_key_alias", "model"` |
Labels: `"hashed_api_key", "api_key_alias", "model"`| + + +### Initialize Budget Metrics on Startup + +If you want to initialize the key/team budget metrics on startup, you can set the `prometheus_initialize_budget_metrics` to `true` in the `config.yaml` + +```yaml +litellm_settings: + callbacks: ["prometheus"] + prometheus_initialize_budget_metrics: true +``` + + ## Proxy Level Tracking Metrics Use this to track overall LiteLLM Proxy usage. @@ -79,12 +115,11 @@ Use this to track overall LiteLLM Proxy usage. | `litellm_proxy_failed_requests_metric` | Total number of failed responses from proxy - the client did not get a success response from litellm proxy. Labels: `"end_user", "hashed_api_key", "api_key_alias", "requested_model", "team", "team_alias", "user", "exception_status", "exception_class"` | | `litellm_proxy_total_requests_metric` | Total number of requests made to the proxy server - track number of client side requests. Labels: `"end_user", "hashed_api_key", "api_key_alias", "requested_model", "team", "team_alias", "user", "status_code"` | -## LLM API / Provider Metrics +## LLM Provider Metrics Use this for LLM API Error monitoring and tracking remaining rate limits and token limits -### Labels Tracked for LLM API Metrics - +### Labels Tracked | Label | Description | |-------|-------------| @@ -100,7 +135,7 @@ Use this for LLM API Error monitoring and tracking remaining rate limits and tok | exception_status | The status of the exception, if any | | exception_class | The class of the exception, if any | -### Success and Failure Metrics for LLM API +### Success and Failure | Metric Name | Description | |----------------------|--------------------------------------| @@ -108,15 +143,14 @@ Use this for LLM API Error monitoring and tracking remaining rate limits and tok | `litellm_deployment_failure_responses` | Total number of failed LLM API calls for a specific LLM deployment. Labels: `"requested_model", "litellm_model_name", "model_id", "api_base", "api_provider", "hashed_api_key", "api_key_alias", "team", "team_alias", "exception_status", "exception_class"` | | `litellm_deployment_total_requests` | Total number of LLM API calls for deployment - success + failure. Labels: `"requested_model", "litellm_model_name", "model_id", "api_base", "api_provider", "hashed_api_key", "api_key_alias", "team", "team_alias"` | -### Remaining Requests and Tokens Metrics +### Remaining Requests and Tokens | Metric Name | Description | |----------------------|--------------------------------------| | `litellm_remaining_requests_metric` | Track `x-ratelimit-remaining-requests` returned from LLM API Deployment. Labels: `"model_group", "api_provider", "api_base", "litellm_model_name", "hashed_api_key", "api_key_alias"` | | `litellm_remaining_tokens` | Track `x-ratelimit-remaining-tokens` return from LLM API Deployment. Labels: `"model_group", "api_provider", "api_base", "litellm_model_name", "hashed_api_key", "api_key_alias"` | -### Deployment State Metrics - +### Deployment State | Metric Name | Description | |----------------------|--------------------------------------| | `litellm_deployment_state` | The state of the deployment: 0 = healthy, 1 = partial outage, 2 = complete outage. 
Labels: `"litellm_model_name", "model_id", "api_base", "api_provider"` | @@ -139,17 +173,6 @@ Use this for LLM API Error monitoring and tracking remaining rate limits and tok | `litellm_llm_api_latency_metric` | Latency (seconds) for just the LLM API call - tracked for labels "model", "hashed_api_key", "api_key_alias", "team", "team_alias", "requested_model", "end_user", "user" | | `litellm_llm_api_time_to_first_token_metric` | Time to first token for LLM API call - tracked for labels `model`, `hashed_api_key`, `api_key_alias`, `team`, `team_alias` [Note: only emitted for streaming requests] | -## Virtual Key - Budget, Rate Limit Metrics - -Metrics used to track LiteLLM Proxy Budgeting and Rate limiting logic - -| Metric Name | Description | -|----------------------|--------------------------------------| -| `litellm_remaining_team_budget_metric` | Remaining Budget for Team (A team created on LiteLLM) Labels: `"team_id", "team_alias"`| -| `litellm_remaining_api_key_budget_metric` | Remaining Budget for API Key (A key Created on LiteLLM) Labels: `"hashed_api_key", "api_key_alias"`| -| `litellm_remaining_api_key_requests_for_model` | Remaining Requests for a LiteLLM virtual API key, only if a model-specific rate limit (rpm) has been set for that virtual key. Labels: `"hashed_api_key", "api_key_alias", "model"`| -| `litellm_remaining_api_key_tokens_for_model` | Remaining Tokens for a LiteLLM virtual API key, only if a model-specific token limit (tpm) has been set for that virtual key. Labels: `"hashed_api_key", "api_key_alias", "model"`| - ## [BETA] Custom Metrics Track custom metrics on prometheus on all events mentioned above. @@ -200,7 +223,6 @@ curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \ ... "metadata_foo": "hello world" ... ``` - ## Monitor System Health To monitor the health of litellm adjacent services (redis / postgres), do: diff --git a/litellm/__init__.py b/litellm/__init__.py index c0c2ee45be..814e04d741 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -88,6 +88,7 @@ callbacks: List[ ] = [] langfuse_default_tags: Optional[List[str]] = None langsmith_batch_size: Optional[int] = None +prometheus_initialize_budget_metrics: Optional[bool] = False argilla_batch_size: Optional[int] = None datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload argilla_transformation_object: Optional[Dict[str, Any]] = None diff --git a/litellm/constants.py b/litellm/constants.py index dff574f0f6..0a3b4ee4c7 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -142,3 +142,5 @@ BATCH_STATUS_POLL_INTERVAL_SECONDS = 3600 # 1 hour BATCH_STATUS_POLL_MAX_ATTEMPTS = 24 # for 24 hours HEALTH_CHECK_TIMEOUT_SECONDS = 60 # 60 seconds + +UI_SESSION_TOKEN_TEAM_ID = "litellm-dashboard" diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index f496dc707c..01e4346afe 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -4,7 +4,7 @@ import asyncio import sys from datetime import datetime, timedelta -from typing import List, Optional, cast +from typing import Any, Awaitable, Callable, List, Literal, Optional, Tuple, cast import litellm from litellm._logging import print_verbose, verbose_logger @@ -1321,6 +1321,10 @@ class PrometheusLogger(CustomLogger): Helper to create tasks for initializing metrics that are required on startup - eg. 
diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py
index f496dc707c..01e4346afe 100644
--- a/litellm/integrations/prometheus.py
+++ b/litellm/integrations/prometheus.py
@@ -4,7 +4,7 @@
 import asyncio
 import sys
 from datetime import datetime, timedelta
-from typing import List, Optional, cast
+from typing import Any, Awaitable, Callable, List, Literal, Optional, Tuple, cast
 
 import litellm
 from litellm._logging import print_verbose, verbose_logger
@@ -1321,6 +1321,10 @@ class PrometheusLogger(CustomLogger):
         Helper to create tasks for initializing metrics that are required on startup - eg. remaining budget metrics
         """
+        if litellm.prometheus_initialize_budget_metrics is not True:
+            verbose_logger.debug("Prometheus: skipping budget metrics initialization")
+            return
+
         try:
             if asyncio.get_running_loop():
                 asyncio.create_task(self._initialize_remaining_budget_metrics())
@@ -1329,15 +1333,20 @@
                 f"No running event loop - skipping budget metrics initialization: {str(e)}"
             )
 
-    async def _initialize_remaining_budget_metrics(self):
+    async def _initialize_budget_metrics(
+        self,
+        data_fetch_function: Callable[..., Awaitable[Tuple[List[Any], Optional[int]]]],
+        set_metrics_function: Callable[[List[Any]], Awaitable[None]],
+        data_type: Literal["teams", "keys"],
+    ):
         """
-        Initialize remaining budget metrics for all teams to avoid metric discrepancies.
+        Generic method to initialize budget metrics for teams or API keys.
 
-        Runs when prometheus logger starts up.
+        Args:
+            data_fetch_function: Function to fetch data with pagination.
+            set_metrics_function: Function to set metrics for the fetched data.
+            data_type: String representing the type of data ("teams" or "keys") for logging purposes.
         """
-        from litellm.proxy.management_endpoints.team_endpoints import (
-            get_paginated_teams,
-        )
         from litellm.proxy.proxy_server import prisma_client
 
         if prisma_client is None:
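The hunk below replaces the team-only initializer with a generic `_initialize_budget_metrics` that pages through any data source using ceiling division. The same pagination pattern in isolation, as a standalone sketch — the in-memory `fetch_page` fetcher here is hypothetical, standing in for the DB calls:

```python
import asyncio
from typing import List, Optional, Tuple


async def fetch_page(page_size: int, page: int) -> Tuple[List[int], Optional[int]]:
    # Stand-in fetcher: 123 fake records served in page-sized slices.
    records = list(range(123))
    start = (page - 1) * page_size
    return records[start : start + page_size], len(records)


async def main() -> None:
    page_size = 50
    data, total_count = await fetch_page(page_size=page_size, page=1)
    # Ceiling division: 123 records at 50 per page -> 3 pages.
    total_pages = (total_count + page_size - 1) // page_size
    for page in range(2, total_pages + 1):
        more, _ = await fetch_page(page_size=page_size, page=page)
        data += more
    assert len(data) == 123


asyncio.run(main())
```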
@@ -1346,28 +1355,120 @@
         try:
             page = 1
             page_size = 50
-            teams, total_count = await get_paginated_teams(
-                prisma_client=prisma_client, page_size=page_size, page=page
+            data, total_count = await data_fetch_function(
+                page_size=page_size, page=page
             )
+            if total_count is None:
+                total_count = len(data)
 
             # Calculate total pages needed
             total_pages = (total_count + page_size - 1) // page_size
 
-            # Set metrics for first page of teams
-            await self._set_team_list_budget_metrics(teams)
+            # Set metrics for first page of data
+            await set_metrics_function(data)
 
             # Get and set metrics for remaining pages
             for page in range(2, total_pages + 1):
-                teams, _ = await get_paginated_teams(
-                    prisma_client=prisma_client, page_size=page_size, page=page
-                )
-                await self._set_team_list_budget_metrics(teams)
+                data, _ = await data_fetch_function(page_size=page_size, page=page)
+                await set_metrics_function(data)
 
         except Exception as e:
             verbose_logger.exception(
-                f"Error initializing team budget metrics: {str(e)}"
+                f"Error initializing {data_type} budget metrics: {str(e)}"
             )
 
+    async def _initialize_team_budget_metrics(self):
+        """
+        Initialize team budget metrics by reusing the generic pagination logic.
+        """
+        from litellm.proxy.management_endpoints.team_endpoints import (
+            get_paginated_teams,
+        )
+        from litellm.proxy.proxy_server import prisma_client
+
+        if prisma_client is None:
+            verbose_logger.debug(
+                "Prometheus: skipping team metrics initialization, DB not initialized"
+            )
+            return
+
+        async def fetch_teams(
+            page_size: int, page: int
+        ) -> Tuple[List[LiteLLM_TeamTable], Optional[int]]:
+            teams, total_count = await get_paginated_teams(
+                prisma_client=prisma_client, page_size=page_size, page=page
+            )
+            if total_count is None:
+                total_count = len(teams)
+            return teams, total_count
+
+        await self._initialize_budget_metrics(
+            data_fetch_function=fetch_teams,
+            set_metrics_function=self._set_team_list_budget_metrics,
+            data_type="teams",
+        )
+
+    async def _initialize_api_key_budget_metrics(self):
+        """
+        Initialize API key budget metrics by reusing the generic pagination logic.
+        """
+        from typing import Union
+
+        from litellm.constants import UI_SESSION_TOKEN_TEAM_ID
+        from litellm.proxy.management_endpoints.key_management_endpoints import (
+            _list_key_helper,
+        )
+        from litellm.proxy.proxy_server import prisma_client
+
+        if prisma_client is None:
+            verbose_logger.debug(
+                "Prometheus: skipping key metrics initialization, DB not initialized"
+            )
+            return
+
+        async def fetch_keys(
+            page_size: int, page: int
+        ) -> Tuple[List[Union[str, UserAPIKeyAuth]], Optional[int]]:
+            key_list_response = await _list_key_helper(
+                prisma_client=prisma_client,
+                page=page,
+                size=page_size,
+                user_id=None,
+                team_id=None,
+                key_alias=None,
+                exclude_team_id=UI_SESSION_TOKEN_TEAM_ID,
+                return_full_object=True,
+            )
+            keys = key_list_response.get("keys", [])
+            total_count = key_list_response.get("total_count")
+            if total_count is None:
+                total_count = len(keys)
+            return keys, total_count
+
+        await self._initialize_budget_metrics(
+            data_fetch_function=fetch_keys,
+            set_metrics_function=self._set_key_list_budget_metrics,
+            data_type="keys",
+        )
+
+    async def _initialize_remaining_budget_metrics(self):
+        """
+        Initialize remaining budget metrics for all teams and keys to avoid metric discrepancies.
+
+        Runs when prometheus logger starts up.
+        """
+        await self._initialize_team_budget_metrics()
+        await self._initialize_api_key_budget_metrics()
+
+    async def _set_key_list_budget_metrics(
+        self, keys: List[Union[str, UserAPIKeyAuth]]
+    ):
+        """Helper function to set budget metrics for a list of keys"""
+        for key in keys:
+            if isinstance(key, UserAPIKeyAuth):
+                self._set_key_budget_metrics(key)
+
     async def _set_team_list_budget_metrics(self, teams: List[LiteLLM_TeamTable]):
         """Helper function to set budget metrics for a list of teams"""
         for team in teams:
@@ -1431,7 +1532,7 @@
                 user_api_key_cache=user_api_key_cache,
             )
         except Exception as e:
-            verbose_logger.exception(
+            verbose_logger.debug(
                 f"[Non-Blocking] Prometheus: Error getting team info: {str(e)}"
             )
         return team_object
@@ -1487,7 +1588,8 @@
         - Budget Reset At
         """
         self.litellm_remaining_api_key_budget_metric.labels(
-            user_api_key_dict.token, user_api_key_dict.key_alias
+            user_api_key_dict.token,
+            user_api_key_dict.key_alias or "",
         ).set(
             self._safe_get_remaining_budget(
                 max_budget=user_api_key_dict.max_budget,
@@ -1558,7 +1660,7 @@
             if key_object:
                 user_api_key_dict.budget_reset_at = key_object.budget_reset_at
         except Exception as e:
-            verbose_logger.exception(
+            verbose_logger.debug(
                 f"[Non-Blocking] Prometheus: Error getting key info: {str(e)}"
             )
 
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 8ac0bc019a..270dab7f21 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -2156,6 +2156,13 @@ class TeamListResponseObject(LiteLLM_TeamTable):
     keys: List  # list of keys that belong to the team
 
 
+class KeyListResponseObject(TypedDict, total=False):
+    keys: List[Union[str, UserAPIKeyAuth]]
+    total_count: Optional[int]
+    current_page: Optional[int]
+    total_pages: Optional[int]
+
+
 class CurrentItemRateLimit(TypedDict):
     current_requests: int
     current_tpm: int
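`KeyListResponseObject` is declared with `total=False`, so every field is optional at the type level and callers (like the `fetch_keys` closure above) should read it defensively. A small sketch of that consumption pattern, assuming litellm with this patch applied is importable:

```python
from litellm.proxy._types import KeyListResponseObject

response = KeyListResponseObject(keys=["sk-hash-1", "sk-hash-2"], total_count=2)

# Use .get() with fallbacks, since total=False means any key may be absent.
keys = response.get("keys", [])
total_count = response.get("total_count")
if total_count is None:
    total_count = len(keys)
```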
diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py
index 02d8f49a34..8761e1ac9f 100644
--- a/litellm/proxy/management_endpoints/key_management_endpoints.py
+++ b/litellm/proxy/management_endpoints/key_management_endpoints.py
@@ -1651,18 +1651,21 @@ async def list_keys(
     user_id: Optional[str] = Query(None, description="Filter keys by user ID"),
     team_id: Optional[str] = Query(None, description="Filter keys by team ID"),
     key_alias: Optional[str] = Query(None, description="Filter keys by key alias"),
-):
+) -> KeyListResponseObject:
+    """
+    List all keys for a given user or team.
+
+    Returns:
+        {
+            "keys": List[str],
+            "total_count": int,
+            "current_page": int,
+            "total_pages": int,
+        }
+    """
     try:
-        import logging
-
         from litellm.proxy.proxy_server import prisma_client
 
-        logging.debug("Entering list_keys function")
-
-        if prisma_client is None:
-            logging.error("Database not connected")
-            raise Exception("Database not connected")
-
         # Check for unsupported parameters
         supported_params = {"page", "size", "user_id", "team_id", "key_alias"}
         unsupported_params = set(request.query_params.keys()) - supported_params
@@ -1674,56 +1677,22 @@ async def list_keys(
                 code=status.HTTP_400_BAD_REQUEST,
             )
 
-        # Prepare filter conditions
-        where = {}
-        if user_id and isinstance(user_id, str):
-            where["user_id"] = user_id
-        if team_id and isinstance(team_id, str):
-            where["team_id"] = team_id
-        if key_alias and isinstance(key_alias, str):
-            where["key_alias"] = key_alias
+        verbose_proxy_logger.debug("Entering list_keys function")
 
-        logging.debug(f"Filter conditions: {where}")
+        if prisma_client is None:
+            verbose_proxy_logger.error("Database not connected")
+            raise Exception("Database not connected")
 
-        # Calculate skip for pagination
-        skip = (page - 1) * size
-
-        logging.debug(f"Pagination: skip={skip}, take={size}")
-
-        # Fetch keys with pagination
-        keys = await prisma_client.db.litellm_verificationtoken.find_many(
-            where=where,  # type: ignore
-            skip=skip,  # type: ignore
-            take=size,  # type: ignore
+        response = await _list_key_helper(
+            prisma_client=prisma_client,
+            page=page,
+            size=size,
+            user_id=user_id,
+            team_id=team_id,
+            key_alias=key_alias,
         )
 
-        logging.debug(f"Fetched {len(keys)} keys")
-
-        # Get total count of keys
-        total_count = await prisma_client.db.litellm_verificationtoken.count(
-            where=where  # type: ignore
-        )
-
-        logging.debug(f"Total count of keys: {total_count}")
-
-        # Calculate total pages
-        total_pages = -(-total_count // size)  # Ceiling division
-
-        # Prepare response
-        key_list = []
-        for key in keys:
-            key_dict = key.dict()
-            _token = key_dict.get("token")
-            key_list.append(_token)
-
-        response = {
-            "keys": key_list,
-            "total_count": total_count,
-            "current_page": page,
-            "total_pages": total_pages,
-        }
-
-        logging.debug("Successfully prepared response")
+        verbose_proxy_logger.debug("Successfully prepared response")
 
         return response
 
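Note that in the `_list_key_helper` added in the next hunk, `exclude_team_id` writes to the same `where["team_id"]` slot as the `team_id` filter, so when both are supplied the exclusion silently wins. A sketch of the resulting Prisma-style filter dicts (`build_where` is a hypothetical standalone helper, no DB call involved):

```python
from typing import Dict, Optional, Union


def build_where(
    team_id: Optional[str] = None, exclude_team_id: Optional[str] = None
) -> Dict[str, Union[str, Dict[str, str]]]:
    where: Dict[str, Union[str, Dict[str, str]]] = {}
    if team_id:
        where["team_id"] = team_id
    if exclude_team_id:
        # Overwrites any team_id filter set above.
        where["team_id"] = {"not": exclude_team_id}
    return where


assert build_where(team_id="team-a") == {"team_id": "team-a"}
assert build_where(exclude_team_id="litellm-dashboard") == {
    "team_id": {"not": "litellm-dashboard"}
}
# Both supplied: the exclusion wins.
assert build_where("team-a", "litellm-dashboard") == {
    "team_id": {"not": "litellm-dashboard"}
}
```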
@@ -1745,6 +1714,91 @@
     )
 
 
+async def _list_key_helper(
+    prisma_client: PrismaClient,
+    page: int,
+    size: int,
+    user_id: Optional[str],
+    team_id: Optional[str],
+    key_alias: Optional[str],
+    exclude_team_id: Optional[str] = None,
+    return_full_object: bool = False,
+) -> KeyListResponseObject:
+    """
+    Helper function to list keys.
+
+    Args:
+        page: int
+        size: int
+        user_id: Optional[str]
+        team_id: Optional[str]
+        key_alias: Optional[str]
+        exclude_team_id: Optional[str]  # exclude a specific team_id
+        return_full_object: bool  # when true, will return UserAPIKeyAuth objects instead of just the token
+
+    Returns:
+        KeyListResponseObject
+        {
+            "keys": List[str] or List[UserAPIKeyAuth],  # depends on return_full_object
+            "total_count": int,
+            "current_page": int,
+            "total_pages": int,
+        }
+    """
+
+    # Prepare filter conditions
+    where: Dict[str, Union[str, Dict[str, str]]] = {}
+    if user_id and isinstance(user_id, str):
+        where["user_id"] = user_id
+    if team_id and isinstance(team_id, str):
+        where["team_id"] = team_id
+    if key_alias and isinstance(key_alias, str):
+        where["key_alias"] = key_alias
+    if exclude_team_id and isinstance(exclude_team_id, str):
+        where["team_id"] = {"not": exclude_team_id}
+
+    verbose_proxy_logger.debug(f"Filter conditions: {where}")
+
+    # Calculate skip for pagination
+    skip = (page - 1) * size
+
+    verbose_proxy_logger.debug(f"Pagination: skip={skip}, take={size}")
+
+    # Fetch keys with pagination
+    keys = await prisma_client.db.litellm_verificationtoken.find_many(
+        where=where,  # type: ignore
+        skip=skip,  # type: ignore
+        take=size,  # type: ignore
+    )
+
+    verbose_proxy_logger.debug(f"Fetched {len(keys)} keys")
+
+    # Get total count of keys
+    total_count = await prisma_client.db.litellm_verificationtoken.count(
+        where=where  # type: ignore
+    )
+
+    verbose_proxy_logger.debug(f"Total count of keys: {total_count}")
+
+    # Calculate total pages
+    total_pages = -(-total_count // size)  # Ceiling division
+
+    # Prepare response
+    key_list: List[Union[str, UserAPIKeyAuth]] = []
+    for key in keys:
+        if return_full_object is True:
+            key_list.append(UserAPIKeyAuth(**key.dict()))  # Return full key object
+        else:
+            _token = key.dict().get("token")
+            key_list.append(_token)  # Return only the token
+
+    return KeyListResponseObject(
+        keys=key_list,
+        total_count=total_count,
+        current_page=page,
+        total_pages=total_pages,
+    )
+
+
 @router.post(
     "/key/block", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
 )
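This is the call shape the Prometheus integration uses at startup: full key objects, with UI-session keys excluded. A condensed sketch (assumes a connected `prisma_client`; error handling omitted):

```python
from litellm.constants import UI_SESSION_TOKEN_TEAM_ID
from litellm.proxy.management_endpoints.key_management_endpoints import (
    _list_key_helper,
)


async def first_page_of_keys(prisma_client):
    # Mirrors fetch_keys() in the Prometheus logger above.
    response = await _list_key_helper(
        prisma_client=prisma_client,
        page=1,
        size=50,
        user_id=None,
        team_id=None,
        key_alias=None,
        exclude_team_id=UI_SESSION_TOKEN_TEAM_ID,  # skip litellm-dashboard session keys
        return_full_object=True,  # UserAPIKeyAuth objects, not bare tokens
    )
    return response.get("keys", [])
```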
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 2f21949792..cf93bd679b 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -15,8 +15,8 @@ model_list:
 
 litellm_settings:
-  callbacks: ["gcs_pubsub"]
-
+  callbacks: ["prometheus"]
+  prometheus_initialize_budget_metrics: true
   guardrails:
     - guardrail_name: "bedrock-pre-guard"
       litellm_params:
diff --git a/tests/logging_callback_tests/test_prometheus_unit_tests.py b/tests/logging_callback_tests/test_prometheus_unit_tests.py
index 7307050d0f..1b157dd335 100644
--- a/tests/logging_callback_tests/test_prometheus_unit_tests.py
+++ b/tests/logging_callback_tests/test_prometheus_unit_tests.py
@@ -981,6 +981,7 @@ async def test_initialize_remaining_budget_metrics(prometheus_logger):
     """
     Test that _initialize_remaining_budget_metrics correctly sets budget metrics for all teams
     """
+    litellm.prometheus_initialize_budget_metrics = True
     # Mock the prisma client and get_paginated_teams function
     with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
         "litellm.proxy.management_endpoints.team_endpoints.get_paginated_teams"
@@ -1076,30 +1077,41 @@ async def test_initialize_remaining_budget_metrics_exception_handling(
     """
     Test that _initialize_remaining_budget_metrics properly handles exceptions
     """
+    litellm.prometheus_initialize_budget_metrics = True
     # Mock the prisma client and get_paginated_teams function to raise an exception
     with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
         "litellm.proxy.management_endpoints.team_endpoints.get_paginated_teams"
-    ) as mock_get_teams:
+    ) as mock_get_teams, patch(
+        "litellm.proxy.management_endpoints.key_management_endpoints._list_key_helper"
+    ) as mock_list_keys:
         # Make get_paginated_teams raise an exception
         mock_get_teams.side_effect = Exception("Database error")
+        mock_list_keys.side_effect = Exception("Key listing error")
 
-        # Mock the Prometheus metric
+        # Mock the Prometheus metrics
         prometheus_logger.litellm_remaining_team_budget_metric = MagicMock()
+        prometheus_logger.litellm_remaining_api_key_budget_metric = MagicMock()
 
         # Mock the logger to capture the error
         with patch("litellm._logging.verbose_logger.exception") as mock_logger:
             # Call the function
             await prometheus_logger._initialize_remaining_budget_metrics()
 
-            # Verify the error was logged
-            mock_logger.assert_called_once()
+            # Verify both errors were logged
+            assert mock_logger.call_count == 2
             assert (
-                "Error initializing team budget metrics" in mock_logger.call_args[0][0]
+                "Error initializing teams budget metrics"
+                in mock_logger.call_args_list[0][0][0]
+            )
+            assert (
+                "Error initializing keys budget metrics"
+                in mock_logger.call_args_list[1][0][0]
             )
 
-        # Verify the metric was never called
+        # Verify the metrics were never called
         prometheus_logger.litellm_remaining_team_budget_metric.assert_not_called()
+        prometheus_logger.litellm_remaining_api_key_budget_metric.assert_not_called()
 
 
 def test_initialize_prometheus_startup_metrics_no_loop(prometheus_logger):
@@ -1107,6 +1119,7 @@ def test_initialize_prometheus_startup_metrics_no_loop(prometheus_logger):
     Test that _initialize_prometheus_startup_metrics handles case when no event loop exists
     """
     # Mock asyncio.get_running_loop to raise RuntimeError
+    litellm.prometheus_initialize_budget_metrics = True
     with patch(
         "asyncio.get_running_loop", side_effect=RuntimeError("No running event loop")
     ), patch("litellm._logging.verbose_logger.exception") as mock_logger:
@@ -1117,3 +1130,109 @@
         # Verify the error was logged
         mock_logger.assert_called_once()
         assert "No running event loop" in mock_logger.call_args[0][0]
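The new test below asserts the hours gauge lands between 23.9 and 24.0 for a reset 24 hours out, which implies the metric is simply the time left until `budget_reset_at`, expressed in hours. A sketch of that arithmetic — the exact helper inside `PrometheusLogger` is not shown in this diff, so this is an assumption about its internals:

```python
from datetime import datetime, timedelta


def remaining_hours(budget_reset_at: datetime) -> float:
    # Assumed computation: seconds until reset, converted to hours.
    return (budget_reset_at - datetime.now()).total_seconds() / 3600


reset_at = datetime.now() + timedelta(hours=24)
assert 23.9 <= remaining_hours(reset_at) <= 24.0
```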
+
+
+@pytest.mark.asyncio(scope="session")
+async def test_initialize_api_key_budget_metrics(prometheus_logger):
+    """
+    Test that _initialize_api_key_budget_metrics correctly sets budget metrics for all API keys
+    """
+    litellm.prometheus_initialize_budget_metrics = True
+    # Mock the prisma client and _list_key_helper function
+    with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
+        "litellm.proxy.management_endpoints.key_management_endpoints._list_key_helper"
+    ) as mock_list_keys:
+
+        # Create mock key data with proper datetime objects for budget_reset_at
+        future_reset = datetime.now() + timedelta(hours=24)  # Reset 24 hours from now
+        key1 = UserAPIKeyAuth(
+            api_key="key1_hash",
+            key_alias="alias1",
+            team_id="team1",
+            max_budget=100,
+            spend=30,
+            budget_reset_at=future_reset,
+        )
+        key1.token = "key1_hash"
+
+        key2 = UserAPIKeyAuth(
+            api_key="key2_hash",
+            key_alias="alias2",
+            team_id="team2",
+            max_budget=200,
+            spend=50,
+            budget_reset_at=future_reset,
+        )
+        key2.token = "key2_hash"
+
+        key3 = UserAPIKeyAuth(
+            api_key="key3_hash",
+            key_alias=None,
+            team_id="team3",
+            max_budget=300,
+            spend=100,
+            budget_reset_at=future_reset,
+        )
+        key3.token = "key3_hash"
+
+        mock_keys = [
+            key1,
+            key2,
+            key3,
+        ]
+
+        # Mock _list_key_helper to return our test data
+        mock_list_keys.return_value = {"keys": mock_keys, "total_count": len(mock_keys)}
+
+        # Mock the Prometheus metrics
+        prometheus_logger.litellm_remaining_api_key_budget_metric = MagicMock()
+        prometheus_logger.litellm_api_key_budget_remaining_hours_metric = MagicMock()
+        prometheus_logger.litellm_api_key_max_budget_metric = MagicMock()
+
+        # Call the function
+        await prometheus_logger._initialize_api_key_budget_metrics()
+
+        # Verify the remaining budget metric was set correctly for each key
+        expected_budget_calls = [
+            call.labels("key1_hash", "alias1").set(70),  # 100 - 30
+            call.labels("key2_hash", "alias2").set(150),  # 200 - 50
+            call.labels("key3_hash", "").set(200),  # 300 - 100
+        ]
+
+        prometheus_logger.litellm_remaining_api_key_budget_metric.assert_has_calls(
+            expected_budget_calls, any_order=True
+        )
+
+        # Get all the calls made to the hours metric
+        hours_calls = (
+            prometheus_logger.litellm_api_key_budget_remaining_hours_metric.mock_calls
+        )
+
+        # Verify the structure and approximate values of the hours calls
+        assert len(hours_calls) == 6  # 3 keys * 2 calls each (labels + set)
+
+        # Helper function to extract hours value from call
+        def get_hours_from_call(call_obj):
+            if "set" in str(call_obj):
+                return call_obj[1][0]  # Extract the hours value
+            return None
+
+        # Verify each key's hours are approximately 24 (within reasonable bounds)
+        hours_values = [
+            get_hours_from_call(call)
+            for call in hours_calls
+            if get_hours_from_call(call) is not None
+        ]
+        for hours in hours_values:
+            assert (
+                23.9 <= hours <= 24.0
+            ), f"Hours value {hours} not within expected range"
+
+        # Verify max budget metric was set correctly for each key
+        expected_max_budget_calls = [
+            call.labels("key1_hash", "alias1").set(100),
+            call.labels("key2_hash", "alias2").set(200),
+            call.labels("key3_hash", "").set(300),
+        ]
+        prometheus_logger.litellm_api_key_max_budget_metric.assert_has_calls(
+            expected_max_budget_calls, any_order=True
+        )
diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py
index 620a650dfd..f443d29715 100644
--- a/tests/proxy_admin_ui_tests/test_key_management.py
+++ b/tests/proxy_admin_ui_tests/test_key_management.py
@@ -842,3 +842,149 @@ async def test_key_update_with_model_specific_params(prisma_client):
         "token": token_hash,
     }
     await update_key_fn(request=request, data=UpdateKeyRequest(**args))
+
+
+@pytest.mark.asyncio
+async def test_list_key_helper(prisma_client):
+    """
+    Test _list_key_helper function with various scenarios:
+    1. Basic pagination
+    2. Filtering by user_id
+    3. Filtering by team_id
+    4. Filtering by key_alias
+    5. Return full object vs token only
+    """
+    from litellm.proxy.management_endpoints.key_management_endpoints import (
+        _list_key_helper,
+    )
+
+    # Setup - create multiple test keys
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+
+    # Create test data
+    test_user_id = f"test_user_{uuid.uuid4()}"
+    test_team_id = f"test_team_{uuid.uuid4()}"
+    test_key_alias = f"test_alias_{uuid.uuid4()}"
+
+    # Create test data with clear patterns
+    test_keys = []
+
+    # 1. Create 2 keys for test user + test team
+    for i in range(2):
+        key = await generate_key_fn(
+            data=GenerateKeyRequest(
+                user_id=test_user_id,
+                team_id=test_team_id,
+                key_alias=f"team_key_{uuid.uuid4()}",  # Make unique with UUID
+            ),
+            user_api_key_dict=UserAPIKeyAuth(
+                user_role=LitellmUserRoles.PROXY_ADMIN,
+                api_key="sk-1234",
+                user_id="admin",
+            ),
+        )
+        test_keys.append(key)
+
+    # 2. Create 1 key for test user (no team)
+    key = await generate_key_fn(
+        data=GenerateKeyRequest(
+            user_id=test_user_id,
+            key_alias=test_key_alias,  # Already unique from earlier UUID generation
+        ),
+        user_api_key_dict=UserAPIKeyAuth(
+            user_role=LitellmUserRoles.PROXY_ADMIN,
+            api_key="sk-1234",
+            user_id="admin",
+        ),
+    )
+    test_keys.append(key)
+
+    # 3. Create 2 keys for other users
+    for i in range(2):
+        key = await generate_key_fn(
+            data=GenerateKeyRequest(
+                user_id=f"other_user_{i}",
+                key_alias=f"other_key_{uuid.uuid4()}",  # Make unique with UUID
+            ),
+            user_api_key_dict=UserAPIKeyAuth(
+                user_role=LitellmUserRoles.PROXY_ADMIN,
+                api_key="sk-1234",
+                user_id="admin",
+            ),
+        )
+        test_keys.append(key)
+
+    # Test 1: Basic pagination
+    result = await _list_key_helper(
+        prisma_client=prisma_client,
+        page=1,
+        size=2,
+        user_id=None,
+        team_id=None,
+        key_alias=None,
+    )
+    assert len(result["keys"]) == 2, "Should return exactly 2 keys"
+    assert result["total_count"] >= 5, "Should have at least 5 total keys"
+    assert result["current_page"] == 1
+    assert isinstance(result["keys"][0], str), "Should return token strings by default"
+
+    # Test 2: Filter by user_id
+    result = await _list_key_helper(
+        prisma_client=prisma_client,
+        page=1,
+        size=10,
+        user_id=test_user_id,
+        team_id=None,
+        key_alias=None,
+    )
+    assert len(result["keys"]) == 3, "Should return exactly 3 keys for test user"
+
+    # Test 3: Filter by team_id
+    result = await _list_key_helper(
+        prisma_client=prisma_client,
+        page=1,
+        size=10,
+        user_id=None,
+        team_id=test_team_id,
+        key_alias=None,
+    )
+    assert len(result["keys"]) == 2, "Should return exactly 2 keys for test team"
+
+    # Test 4: Filter by key_alias
+    result = await _list_key_helper(
+        prisma_client=prisma_client,
+        page=1,
+        size=10,
+        user_id=None,
+        team_id=None,
+        key_alias=test_key_alias,
+    )
+    assert len(result["keys"]) == 1, "Should return exactly 1 key with test alias"
+
+    # Test 5: Return full object
+    result = await _list_key_helper(
+        prisma_client=prisma_client,
+        page=1,
+        size=10,
+        user_id=test_user_id,
+        team_id=None,
+        key_alias=None,
+        return_full_object=True,
+    )
+    assert all(
+        isinstance(key, UserAPIKeyAuth) for key in result["keys"]
+    ), "Should return UserAPIKeyAuth objects"
+    assert len(result["keys"]) == 3, "Should return exactly 3 keys for test user"
+
+    # Clean up test keys
+    for key in test_keys:
+        await delete_key_fn(
+            data=KeyRequest(keys=[key.key]),
+            user_api_key_dict=UserAPIKeyAuth(
+                user_role=LitellmUserRoles.PROXY_ADMIN,
+                api_key="sk-1234",
+                user_id="admin",
+            ),
+        )
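With `list_keys` now returning a `KeyListResponseObject`, the endpoint can be exercised end to end. A minimal sketch against a local proxy — the URL and master key are the quick-start defaults, and the `/key/list` route path is assumed from the function name since the router decorator is not shown in this diff:

```python
import requests

resp = requests.get(
    "http://localhost:4000/key/list",
    headers={"Authorization": "Bearer sk-1234"},
    params={"page": 1, "size": 10},
)
resp.raise_for_status()
body = resp.json()

# Pagination metadata plus the first few key tokens.
print(body["total_count"], body["total_pages"], body["keys"][:3])
```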