(Prometheus) - emit key budget metrics on startup (#8002)

* add UI_SESSION_TOKEN_TEAM_ID

* add type KeyListResponseObject

* add _list_key_helper

* _initialize_api_key_budget_metrics

* key / budget metrics

* init key budget metrics on startup

* test_initialize_api_key_budget_metrics

* fix linting

* test_list_key_helper

* test_initialize_remaining_budget_metrics_exception_handling
Ishaan Jaff 2025-01-25 10:37:52 -08:00 committed by GitHub
parent d9dcfccdf6
commit 669b4fc955
9 changed files with 555 additions and 102 deletions


@ -57,7 +57,7 @@ http://localhost:4000/metrics
# <proxy_base_url>/metrics
```
## Virtual Keys, Teams, Internal Users Metrics
## Virtual Keys, Teams, Internal Users
Use this for tracking per [user, key, team, etc.](virtual_keys)
@ -68,6 +68,42 @@ Use this for tracking per [user, key, team, etc.](virtual_keys)
| `litellm_input_tokens` | input tokens per `"end_user", "hashed_api_key", "api_key_alias", "requested_model", "team", "team_alias", "user", "model"` |
| `litellm_output_tokens` | output tokens per `"end_user", "hashed_api_key", "api_key_alias", "requested_model", "team", "team_alias", "user", "model"` |
### Team - Budget
| Metric Name | Description |
|----------------------|--------------------------------------|
| `litellm_team_max_budget_metric` | Max budget for a team. Labels: `"team_id", "team_alias"`|
| `litellm_remaining_team_budget_metric` | Remaining budget for a team (a team created on LiteLLM). Labels: `"team_id", "team_alias"`|
| `litellm_team_budget_remaining_hours_metric` | Hours until the team budget is reset. Labels: `"team_id", "team_alias"`|
### Virtual Key - Budget
| Metric Name | Description |
|----------------------|--------------------------------------|
| `litellm_api_key_max_budget_metric` | Max budget for an API key. Labels: `"hashed_api_key", "api_key_alias"`|
| `litellm_remaining_api_key_budget_metric` | Remaining budget for an API key (a key created on LiteLLM). Labels: `"hashed_api_key", "api_key_alias"`|
| `litellm_api_key_budget_remaining_hours_metric` | Hours until the API key budget is reset. Labels: `"hashed_api_key", "api_key_alias"`|
### Virtual Key - Rate Limit
| Metric Name | Description |
|----------------------|--------------------------------------|
| `litellm_remaining_api_key_requests_for_model` | Remaining Requests for a LiteLLM virtual API key, only if a model-specific rate limit (rpm) has been set for that virtual key. Labels: `"hashed_api_key", "api_key_alias", "model"`|
| `litellm_remaining_api_key_tokens_for_model` | Remaining Tokens for a LiteLLM virtual API key, only if a model-specific token limit (tpm) has been set for that virtual key. Labels: `"hashed_api_key", "api_key_alias", "model"`|
### Initialize Budget Metrics on Startup
To initialize the key/team budget metrics on startup, set `prometheus_initialize_budget_metrics` to `true` in your `config.yaml`:
```yaml
litellm_settings:
callbacks: ["prometheus"]
prometheus_initialize_budget_metrics: true
```
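
Once the proxy restarts with this config, the budget gauges listed above should appear on the very first scrape of `/metrics`, before any traffic arrives. A minimal sketch for spot-checking that, assuming the proxy is reachable at `http://localhost:4000` (as in the quick start above) and that `requests` is installed:

```python
# Sketch: verify the budget gauges are present right after startup.
# Adjust the base URL / auth headers for your deployment.
import requests

EXPECTED_METRICS = [
    "litellm_team_max_budget_metric",
    "litellm_remaining_team_budget_metric",
    "litellm_api_key_max_budget_metric",
    "litellm_remaining_api_key_budget_metric",
]

resp = requests.get("http://localhost:4000/metrics", timeout=10)
resp.raise_for_status()

for metric in EXPECTED_METRICS:
    print(metric, "found" if metric in resp.text else "MISSING")
```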
## Proxy Level Tracking Metrics
Use this to track overall LiteLLM Proxy usage.
@ -79,12 +115,11 @@ Use this to track overall LiteLLM Proxy usage.
| `litellm_proxy_failed_requests_metric` | Total number of failed responses from proxy - the client did not get a success response from litellm proxy. Labels: `"end_user", "hashed_api_key", "api_key_alias", "requested_model", "team", "team_alias", "user", "exception_status", "exception_class"` |
| `litellm_proxy_total_requests_metric` | Total number of requests made to the proxy server - track number of client side requests. Labels: `"end_user", "hashed_api_key", "api_key_alias", "requested_model", "team", "team_alias", "user", "status_code"` |
## LLM API / Provider Metrics
## LLM Provider Metrics
Use this for LLM API error monitoring and for tracking remaining rate limits and token limits.
### Labels Tracked for LLM API Metrics
### Labels Tracked
| Label | Description |
|-------|-------------|
@ -100,7 +135,7 @@ Use this for LLM API Error monitoring and tracking remaining rate limits and tok
| exception_status | The status of the exception, if any |
| exception_class | The class of the exception, if any |
### Success and Failure Metrics for LLM API
### Success and Failure
| Metric Name | Description |
|----------------------|--------------------------------------|
@ -108,15 +143,14 @@ Use this for LLM API Error monitoring and tracking remaining rate limits and tok
| `litellm_deployment_failure_responses` | Total number of failed LLM API calls for a specific LLM deployment. Labels: `"requested_model", "litellm_model_name", "model_id", "api_base", "api_provider", "hashed_api_key", "api_key_alias", "team", "team_alias", "exception_status", "exception_class"` |
| `litellm_deployment_total_requests` | Total number of LLM API calls for deployment - success + failure. Labels: `"requested_model", "litellm_model_name", "model_id", "api_base", "api_provider", "hashed_api_key", "api_key_alias", "team", "team_alias"` |
### Remaining Requests and Tokens Metrics
### Remaining Requests and Tokens
| Metric Name | Description |
|----------------------|--------------------------------------|
| `litellm_remaining_requests_metric` | Track `x-ratelimit-remaining-requests` returned from LLM API Deployment. Labels: `"model_group", "api_provider", "api_base", "litellm_model_name", "hashed_api_key", "api_key_alias"` |
| `litellm_remaining_tokens` | Track `x-ratelimit-remaining-tokens` returned from LLM API Deployment. Labels: `"model_group", "api_provider", "api_base", "litellm_model_name", "hashed_api_key", "api_key_alias"` |
### Deployment State Metrics
### Deployment State
| Metric Name | Description |
|----------------------|--------------------------------------|
| `litellm_deployment_state` | The state of the deployment: 0 = healthy, 1 = partial outage, 2 = complete outage. Labels: `"litellm_model_name", "model_id", "api_base", "api_provider"` |
@ -139,17 +173,6 @@ Use this for LLM API Error monitoring and tracking remaining rate limits and tok
| `litellm_llm_api_latency_metric` | Latency (seconds) for just the LLM API call - tracked for labels `model`, `hashed_api_key`, `api_key_alias`, `team`, `team_alias`, `requested_model`, `end_user`, `user` |
| `litellm_llm_api_time_to_first_token_metric` | Time to first token for LLM API call - tracked for labels `model`, `hashed_api_key`, `api_key_alias`, `team`, `team_alias` [Note: only emitted for streaming requests] |
## Virtual Key - Budget, Rate Limit Metrics
Metrics used to track LiteLLM Proxy Budgeting and Rate limiting logic
| Metric Name | Description |
|----------------------|--------------------------------------|
| `litellm_remaining_team_budget_metric` | Remaining Budget for Team (A team created on LiteLLM) Labels: `"team_id", "team_alias"`|
| `litellm_remaining_api_key_budget_metric` | Remaining Budget for API Key (A key Created on LiteLLM) Labels: `"hashed_api_key", "api_key_alias"`|
| `litellm_remaining_api_key_requests_for_model` | Remaining Requests for a LiteLLM virtual API key, only if a model-specific rate limit (rpm) has been set for that virtual key. Labels: `"hashed_api_key", "api_key_alias", "model"`|
| `litellm_remaining_api_key_tokens_for_model` | Remaining Tokens for a LiteLLM virtual API key, only if a model-specific token limit (tpm) has been set for that virtual key. Labels: `"hashed_api_key", "api_key_alias", "model"`|
## [BETA] Custom Metrics
Track custom metrics on Prometheus for all the events mentioned above.
@ -200,7 +223,6 @@ curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
... "metadata_foo": "hello world" ...
```
## Monitor System Health
To monitor the health of litellm adjacent services (redis / postgres), do:


@ -88,6 +88,7 @@ callbacks: List[
] = []
langfuse_default_tags: Optional[List[str]] = None
langsmith_batch_size: Optional[int] = None
prometheus_initialize_budget_metrics: Optional[bool] = False
argilla_batch_size: Optional[int] = None
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
argilla_transformation_object: Optional[Dict[str, Any]] = None
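
The same flag also exists as a module-level setting (added above), which is what the tests in this PR flip before exercising the logger. A sketch of enabling it from Python rather than YAML — an illustration only, the `config.yaml` route above is the documented path:

```python
# Sketch: programmatic equivalent of the config.yaml settings above.
# Mirrors what the tests in this PR do before invoking the Prometheus logger.
import litellm

litellm.callbacks = ["prometheus"]
litellm.prometheus_initialize_budget_metrics = True
```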


@ -142,3 +142,5 @@ BATCH_STATUS_POLL_INTERVAL_SECONDS = 3600 # 1 hour
BATCH_STATUS_POLL_MAX_ATTEMPTS = 24 # for 24 hours
HEALTH_CHECK_TIMEOUT_SECONDS = 60 # 60 seconds
UI_SESSION_TOKEN_TEAM_ID = "litellm-dashboard"


@ -4,7 +4,7 @@
import asyncio
import sys
from datetime import datetime, timedelta
from typing import List, Optional, cast
from typing import Any, Awaitable, Callable, List, Literal, Optional, Tuple, cast
import litellm
from litellm._logging import print_verbose, verbose_logger
@ -1321,6 +1321,10 @@ class PrometheusLogger(CustomLogger):
Helper to create tasks for initializing metrics that are required on startup - eg. remaining budget metrics
"""
if litellm.prometheus_initialize_budget_metrics is not True:
verbose_logger.debug("Prometheus: skipping budget metrics initialization")
return
try:
if asyncio.get_running_loop():
asyncio.create_task(self._initialize_remaining_budget_metrics())
@ -1329,15 +1333,20 @@ class PrometheusLogger(CustomLogger):
f"No running event loop - skipping budget metrics initialization: {str(e)}"
)
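
The guard above follows a common asyncio pattern: schedule the initialization as a fire-and-forget task if an event loop is already running, and log-and-skip otherwise. A stand-alone sketch of that pattern with illustrative names (not LiteLLM APIs):

```python
# Sketch of the "schedule if a loop is running, otherwise skip" pattern.
# `initialize_metrics` is a placeholder coroutine, not a LiteLLM function.
import asyncio
import logging

logger = logging.getLogger(__name__)


async def initialize_metrics() -> None:
    logger.info("initializing startup metrics...")


def schedule_startup_metrics() -> None:
    try:
        asyncio.get_running_loop()  # raises RuntimeError when no loop is running
        asyncio.create_task(initialize_metrics())
    except RuntimeError as e:
        logger.debug("No running event loop - skipping initialization: %s", e)
```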
async def _initialize_remaining_budget_metrics(self):
async def _initialize_budget_metrics(
self,
data_fetch_function: Callable[..., Awaitable[Tuple[List[Any], Optional[int]]]],
set_metrics_function: Callable[[List[Any]], Awaitable[None]],
data_type: Literal["teams", "keys"],
):
"""
Initialize remaining budget metrics for all teams to avoid metric discrepancies.
Generic method to initialize budget metrics for teams or API keys.
Runs when prometheus logger starts up.
Args:
data_fetch_function: Function to fetch data with pagination.
set_metrics_function: Function to set metrics for the fetched data.
data_type: String representing the type of data ("teams" or "keys") for logging purposes.
"""
from litellm.proxy.management_endpoints.team_endpoints import (
get_paginated_teams,
)
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
@ -1346,28 +1355,120 @@ class PrometheusLogger(CustomLogger):
try:
page = 1
page_size = 50
teams, total_count = await get_paginated_teams(
prisma_client=prisma_client, page_size=page_size, page=page
data, total_count = await data_fetch_function(
page_size=page_size, page=page
)
if total_count is None:
total_count = len(data)
# Calculate total pages needed
total_pages = (total_count + page_size - 1) // page_size
# Set metrics for first page of teams
await self._set_team_list_budget_metrics(teams)
# Set metrics for first page of data
await set_metrics_function(data)
# Get and set metrics for remaining pages
for page in range(2, total_pages + 1):
teams, _ = await get_paginated_teams(
prisma_client=prisma_client, page_size=page_size, page=page
)
await self._set_team_list_budget_metrics(teams)
data, _ = await data_fetch_function(page_size=page_size, page=page)
await set_metrics_function(data)
except Exception as e:
verbose_logger.exception(
f"Error initializing team budget metrics: {str(e)}"
f"Error initializing {data_type} budget metrics: {str(e)}"
)
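
The page count above is plain ceiling division: with `page_size = 50` and 120 records, `total_pages = (120 + 50 - 1) // 50 = 3`. A stripped-down sketch of the same fetch-first-page / derive-page-count / iterate pattern, independent of LiteLLM's types (all names below are illustrative):

```python
# Sketch of the paginate-and-apply pattern used by _initialize_budget_metrics.
# `fetch_page` and `apply` are placeholders, not LiteLLM functions.
from typing import Any, Awaitable, Callable, List, Optional, Tuple


async def paginate_and_apply(
    fetch_page: Callable[[int, int], Awaitable[Tuple[List[Any], Optional[int]]]],
    apply: Callable[[List[Any]], Awaitable[None]],
    page_size: int = 50,
) -> None:
    items, total_count = await fetch_page(page_size, 1)
    if total_count is None:
        total_count = len(items)
    total_pages = (total_count + page_size - 1) // page_size  # ceiling division
    await apply(items)  # first page
    for page in range(2, total_pages + 1):
        items, _ = await fetch_page(page_size, page)
        await apply(items)
```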
async def _initialize_team_budget_metrics(self):
"""
Initialize team budget metrics by reusing the generic pagination logic.
"""
from litellm.proxy.management_endpoints.team_endpoints import (
get_paginated_teams,
)
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
verbose_logger.debug(
"Prometheus: skipping team metrics initialization, DB not initialized"
)
return
async def fetch_teams(
page_size: int, page: int
) -> Tuple[List[LiteLLM_TeamTable], Optional[int]]:
teams, total_count = await get_paginated_teams(
prisma_client=prisma_client, page_size=page_size, page=page
)
if total_count is None:
total_count = len(teams)
return teams, total_count
await self._initialize_budget_metrics(
data_fetch_function=fetch_teams,
set_metrics_function=self._set_team_list_budget_metrics,
data_type="teams",
)
async def _initialize_api_key_budget_metrics(self):
"""
Initialize API key budget metrics by reusing the generic pagination logic.
"""
from typing import Union
from litellm.constants import UI_SESSION_TOKEN_TEAM_ID
from litellm.proxy.management_endpoints.key_management_endpoints import (
_list_key_helper,
)
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
verbose_logger.debug(
"Prometheus: skipping key metrics initialization, DB not initialized"
)
return
async def fetch_keys(
page_size: int, page: int
) -> Tuple[List[Union[str, UserAPIKeyAuth]], Optional[int]]:
key_list_response = await _list_key_helper(
prisma_client=prisma_client,
page=page,
size=page_size,
user_id=None,
team_id=None,
key_alias=None,
exclude_team_id=UI_SESSION_TOKEN_TEAM_ID,
return_full_object=True,
)
keys = key_list_response.get("keys", [])
total_count = key_list_response.get("total_count")
if total_count is None:
total_count = len(keys)
return keys, total_count
await self._initialize_budget_metrics(
data_fetch_function=fetch_keys,
set_metrics_function=self._set_key_list_budget_metrics,
data_type="keys",
)
async def _initialize_remaining_budget_metrics(self):
"""
Initialize remaining budget metrics for all teams to avoid metric discrepancies.
Runs when prometheus logger starts up.
"""
await self._initialize_team_budget_metrics()
await self._initialize_api_key_budget_metrics()
async def _set_key_list_budget_metrics(
self, keys: List[Union[str, UserAPIKeyAuth]]
):
"""Helper function to set budget metrics for a list of keys"""
for key in keys:
if isinstance(key, UserAPIKeyAuth):
self._set_key_budget_metrics(key)
async def _set_team_list_budget_metrics(self, teams: List[LiteLLM_TeamTable]):
"""Helper function to set budget metrics for a list of teams"""
for team in teams:
@ -1431,7 +1532,7 @@ class PrometheusLogger(CustomLogger):
user_api_key_cache=user_api_key_cache,
)
except Exception as e:
verbose_logger.exception(
verbose_logger.debug(
f"[Non-Blocking] Prometheus: Error getting team info: {str(e)}"
)
return team_object
@ -1487,7 +1588,8 @@ class PrometheusLogger(CustomLogger):
- Budget Reset At
"""
self.litellm_remaining_api_key_budget_metric.labels(
user_api_key_dict.token, user_api_key_dict.key_alias
user_api_key_dict.token,
user_api_key_dict.key_alias or "",
).set(
self._safe_get_remaining_budget(
max_budget=user_api_key_dict.max_budget,
@ -1558,7 +1660,7 @@ class PrometheusLogger(CustomLogger):
if key_object:
user_api_key_dict.budget_reset_at = key_object.budget_reset_at
except Exception as e:
verbose_logger.exception(
verbose_logger.debug(
f"[Non-Blocking] Prometheus: Error getting key info: {str(e)}"
)


@ -2156,6 +2156,13 @@ class TeamListResponseObject(LiteLLM_TeamTable):
keys: List # list of keys that belong to the team
class KeyListResponseObject(TypedDict, total=False):
keys: List[Union[str, UserAPIKeyAuth]]
total_count: Optional[int]
current_page: Optional[int]
total_pages: Optional[int]
class CurrentItemRateLimit(TypedDict):
current_requests: int
current_tpm: int
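
Because `KeyListResponseObject` is a `TypedDict` with `total=False`, every field is optional and callers read it defensively with `.get(...)` (as the Prometheus startup code above does). A small usage sketch, assuming the type lives in `litellm.proxy._types` alongside the other response objects shown here:

```python
# Sketch: building and reading a KeyListResponseObject defensively.
# Assumes the type is importable from litellm.proxy._types.
from litellm.proxy._types import KeyListResponseObject

response = KeyListResponseObject(
    keys=["sk-hash-1", "sk-hash-2"],
    total_count=2,
    current_page=1,
    total_pages=1,
)

keys = response.get("keys", [])
total_count = response.get("total_count") or len(keys)
print(f"{len(keys)} keys on this page, {total_count} total")
```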


@ -1651,18 +1651,21 @@ async def list_keys(
user_id: Optional[str] = Query(None, description="Filter keys by user ID"),
team_id: Optional[str] = Query(None, description="Filter keys by team ID"),
key_alias: Optional[str] = Query(None, description="Filter keys by key alias"),
):
) -> KeyListResponseObject:
"""
List all keys for a given user or team.
Returns:
{
"keys": List[str],
"total_count": int,
"current_page": int,
"total_pages": int,
}
"""
try:
import logging
from litellm.proxy.proxy_server import prisma_client
logging.debug("Entering list_keys function")
if prisma_client is None:
logging.error("Database not connected")
raise Exception("Database not connected")
# Check for unsupported parameters
supported_params = {"page", "size", "user_id", "team_id", "key_alias"}
unsupported_params = set(request.query_params.keys()) - supported_params
@ -1674,56 +1677,22 @@ async def list_keys(
code=status.HTTP_400_BAD_REQUEST,
)
# Prepare filter conditions
where = {}
if user_id and isinstance(user_id, str):
where["user_id"] = user_id
if team_id and isinstance(team_id, str):
where["team_id"] = team_id
if key_alias and isinstance(key_alias, str):
where["key_alias"] = key_alias
verbose_proxy_logger.debug("Entering list_keys function")
logging.debug(f"Filter conditions: {where}")
if prisma_client is None:
verbose_proxy_logger.error("Database not connected")
raise Exception("Database not connected")
# Calculate skip for pagination
skip = (page - 1) * size
logging.debug(f"Pagination: skip={skip}, take={size}")
# Fetch keys with pagination
keys = await prisma_client.db.litellm_verificationtoken.find_many(
where=where, # type: ignore
skip=skip, # type: ignore
take=size, # type: ignore
response = await _list_key_helper(
prisma_client=prisma_client,
page=page,
size=size,
user_id=user_id,
team_id=team_id,
key_alias=key_alias,
)
logging.debug(f"Fetched {len(keys)} keys")
# Get total count of keys
total_count = await prisma_client.db.litellm_verificationtoken.count(
where=where # type: ignore
)
logging.debug(f"Total count of keys: {total_count}")
# Calculate total pages
total_pages = -(-total_count // size) # Ceiling division
# Prepare response
key_list = []
for key in keys:
key_dict = key.dict()
_token = key_dict.get("token")
key_list.append(_token)
response = {
"keys": key_list,
"total_count": total_count,
"current_page": page,
"total_pages": total_pages,
}
logging.debug("Successfully prepared response")
verbose_proxy_logger.debug("Successfully prepared response")
return response
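
For reference, a sketch of exercising the refactored endpoint over HTTP; it assumes the `list_keys` route is `GET /key/list`, the proxy runs at `http://localhost:4000`, and an admin virtual key is available:

```python
# Sketch: hitting the key-listing endpoint with the supported query params.
# Assumes the route is GET /key/list and ADMIN_KEY holds an admin virtual key.
import os
import requests

ADMIN_KEY = os.environ.get("ADMIN_KEY", "sk-1234")

resp = requests.get(
    "http://localhost:4000/key/list",
    headers={"Authorization": f"Bearer {ADMIN_KEY}"},
    params={"page": 1, "size": 10},
)
resp.raise_for_status()
data = resp.json()
print(data["total_count"], "keys total; page", data["current_page"], "of", data["total_pages"])
```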
@ -1745,6 +1714,91 @@ async def list_keys(
)
async def _list_key_helper(
prisma_client: PrismaClient,
page: int,
size: int,
user_id: Optional[str],
team_id: Optional[str],
key_alias: Optional[str],
exclude_team_id: Optional[str] = None,
return_full_object: bool = False,
) -> KeyListResponseObject:
"""
Helper function to list keys
Args:
page: int
size: int
user_id: Optional[str]
team_id: Optional[str]
key_alias: Optional[str]
exclude_team_id: Optional[str] # exclude a specific team_id
return_full_object: bool # when true, will return UserAPIKeyAuth objects instead of just the token
Returns:
KeyListResponseObject
{
"keys": List[str] or List[UserAPIKeyAuth], # Updated to reflect possible return types
"total_count": int,
"current_page": int,
"total_pages": int,
}
"""
# Prepare filter conditions
where: Dict[str, Union[str, Dict[str, str]]] = {}
if user_id and isinstance(user_id, str):
where["user_id"] = user_id
if team_id and isinstance(team_id, str):
where["team_id"] = team_id
if key_alias and isinstance(key_alias, str):
where["key_alias"] = key_alias
if exclude_team_id and isinstance(exclude_team_id, str):
where["team_id"] = {"not": exclude_team_id}
verbose_proxy_logger.debug(f"Filter conditions: {where}")
# Calculate skip for pagination
skip = (page - 1) * size
verbose_proxy_logger.debug(f"Pagination: skip={skip}, take={size}")
# Fetch keys with pagination
keys = await prisma_client.db.litellm_verificationtoken.find_many(
where=where, # type: ignore
skip=skip, # type: ignore
take=size, # type: ignore
)
verbose_proxy_logger.debug(f"Fetched {len(keys)} keys")
# Get total count of keys
total_count = await prisma_client.db.litellm_verificationtoken.count(
where=where # type: ignore
)
verbose_proxy_logger.debug(f"Total count of keys: {total_count}")
# Calculate total pages
total_pages = -(-total_count // size) # Ceiling division
# Prepare response
key_list: List[Union[str, UserAPIKeyAuth]] = []
for key in keys:
if return_full_object is True:
key_list.append(UserAPIKeyAuth(**key.dict())) # Return full key object
else:
_token = key.dict().get("token")
key_list.append(_token) # Return only the token
return KeyListResponseObject(
keys=key_list,
total_count=total_count,
current_page=page,
total_pages=total_pages,
)
@router.post(
"/key/block", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
)


@ -15,8 +15,8 @@ model_list:
litellm_settings:
callbacks: ["gcs_pubsub"]
callbacks: ["prometheus"]
prometheus_initialize_budget_metrics: true
guardrails:
- guardrail_name: "bedrock-pre-guard"
litellm_params:


@ -981,6 +981,7 @@ async def test_initialize_remaining_budget_metrics(prometheus_logger):
"""
Test that _initialize_remaining_budget_metrics correctly sets budget metrics for all teams
"""
litellm.prometheus_initialize_budget_metrics = True
# Mock the prisma client and get_paginated_teams function
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
"litellm.proxy.management_endpoints.team_endpoints.get_paginated_teams"
@ -1076,30 +1077,41 @@ async def test_initialize_remaining_budget_metrics_exception_handling(
"""
Test that _initialize_remaining_budget_metrics properly handles exceptions
"""
litellm.prometheus_initialize_budget_metrics = True
# Mock the prisma client and get_paginated_teams function to raise an exception
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
"litellm.proxy.management_endpoints.team_endpoints.get_paginated_teams"
) as mock_get_teams:
) as mock_get_teams, patch(
"litellm.proxy.management_endpoints.key_management_endpoints._list_key_helper"
) as mock_list_keys:
# Make get_paginated_teams raise an exception
mock_get_teams.side_effect = Exception("Database error")
mock_list_keys.side_effect = Exception("Key listing error")
# Mock the Prometheus metric
# Mock the Prometheus metrics
prometheus_logger.litellm_remaining_team_budget_metric = MagicMock()
prometheus_logger.litellm_remaining_api_key_budget_metric = MagicMock()
# Mock the logger to capture the error
with patch("litellm._logging.verbose_logger.exception") as mock_logger:
# Call the function
await prometheus_logger._initialize_remaining_budget_metrics()
# Verify the error was logged
mock_logger.assert_called_once()
# Verify both errors were logged
assert mock_logger.call_count == 2
assert (
"Error initializing team budget metrics" in mock_logger.call_args[0][0]
"Error initializing teams budget metrics"
in mock_logger.call_args_list[0][0][0]
)
assert (
"Error initializing keys budget metrics"
in mock_logger.call_args_list[1][0][0]
)
# Verify the metric was never called
# Verify the metrics were never called
prometheus_logger.litellm_remaining_team_budget_metric.assert_not_called()
prometheus_logger.litellm_remaining_api_key_budget_metric.assert_not_called()
def test_initialize_prometheus_startup_metrics_no_loop(prometheus_logger):
@ -1107,6 +1119,7 @@ def test_initialize_prometheus_startup_metrics_no_loop(prometheus_logger):
Test that _initialize_prometheus_startup_metrics handles case when no event loop exists
"""
# Mock asyncio.get_running_loop to raise RuntimeError
litellm.prometheus_initialize_budget_metrics = True
with patch(
"asyncio.get_running_loop", side_effect=RuntimeError("No running event loop")
), patch("litellm._logging.verbose_logger.exception") as mock_logger:
@ -1117,3 +1130,109 @@ def test_initialize_prometheus_startup_metrics_no_loop(prometheus_logger):
# Verify the error was logged
mock_logger.assert_called_once()
assert "No running event loop" in mock_logger.call_args[0][0]
@pytest.mark.asyncio(scope="session")
async def test_initialize_api_key_budget_metrics(prometheus_logger):
"""
Test that _initialize_api_key_budget_metrics correctly sets budget metrics for all API keys
"""
litellm.prometheus_initialize_budget_metrics = True
# Mock the prisma client and _list_key_helper function
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
"litellm.proxy.management_endpoints.key_management_endpoints._list_key_helper"
) as mock_list_keys:
# Create mock key data with proper datetime objects for budget_reset_at
future_reset = datetime.now() + timedelta(hours=24) # Reset 24 hours from now
key1 = UserAPIKeyAuth(
api_key="key1_hash",
key_alias="alias1",
team_id="team1",
max_budget=100,
spend=30,
budget_reset_at=future_reset,
)
key1.token = "key1_hash"
key2 = UserAPIKeyAuth(
api_key="key2_hash",
key_alias="alias2",
team_id="team2",
max_budget=200,
spend=50,
budget_reset_at=future_reset,
)
key2.token = "key2_hash"
key3 = UserAPIKeyAuth(
api_key="key3_hash",
key_alias=None,
team_id="team3",
max_budget=300,
spend=100,
budget_reset_at=future_reset,
)
key3.token = "key3_hash"
mock_keys = [
key1,
key2,
key3,
]
# Mock _list_key_helper to return our test data
mock_list_keys.return_value = {"keys": mock_keys, "total_count": len(mock_keys)}
# Mock the Prometheus metrics
prometheus_logger.litellm_remaining_api_key_budget_metric = MagicMock()
prometheus_logger.litellm_api_key_budget_remaining_hours_metric = MagicMock()
prometheus_logger.litellm_api_key_max_budget_metric = MagicMock()
# Call the function
await prometheus_logger._initialize_api_key_budget_metrics()
# Verify the remaining budget metric was set correctly for each key
expected_budget_calls = [
call.labels("key1_hash", "alias1").set(70), # 100 - 30
call.labels("key2_hash", "alias2").set(150), # 200 - 50
call.labels("key3_hash", "").set(200), # 300 - 100
]
prometheus_logger.litellm_remaining_api_key_budget_metric.assert_has_calls(
expected_budget_calls, any_order=True
)
# Get all the calls made to the hours metric
hours_calls = (
prometheus_logger.litellm_api_key_budget_remaining_hours_metric.mock_calls
)
# Verify the structure and approximate values of the hours calls
assert len(hours_calls) == 6 # 3 keys * 2 calls each (labels + set)
# Helper function to extract hours value from call
def get_hours_from_call(call_obj):
if "set" in str(call_obj):
return call_obj[1][0] # Extract the hours value
return None
# Verify each key's hours are approximately 24 (within reasonable bounds)
hours_values = [
get_hours_from_call(call)
for call in hours_calls
if get_hours_from_call(call) is not None
]
for hours in hours_values:
assert (
23.9 <= hours <= 24.0
), f"Hours value {hours} not within expected range"
# Verify max budget metric was set correctly for each key
expected_max_budget_calls = [
call.labels("key1_hash", "alias1").set(100),
call.labels("key2_hash", "alias2").set(200),
call.labels("key3_hash", "").set(300),
]
prometheus_logger.litellm_api_key_max_budget_metric.assert_has_calls(
expected_max_budget_calls, any_order=True
)


@ -842,3 +842,149 @@ async def test_key_update_with_model_specific_params(prisma_client):
"token": token_hash,
}
await update_key_fn(request=request, data=UpdateKeyRequest(**args))
@pytest.mark.asyncio
async def test_list_key_helper(prisma_client):
"""
Test _list_key_helper function with various scenarios:
1. Basic pagination
2. Filtering by user_id
3. Filtering by team_id
4. Filtering by key_alias
5. Return full object vs token only
"""
from litellm.proxy.management_endpoints.key_management_endpoints import (
_list_key_helper,
)
# Setup - create multiple test keys
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
# Create test data
test_user_id = f"test_user_{uuid.uuid4()}"
test_team_id = f"test_team_{uuid.uuid4()}"
test_key_alias = f"test_alias_{uuid.uuid4()}"
# Create test data with clear patterns
test_keys = []
# 1. Create 2 keys for test user + test team
for i in range(2):
key = await generate_key_fn(
data=GenerateKeyRequest(
user_id=test_user_id,
team_id=test_team_id,
key_alias=f"team_key_{uuid.uuid4()}", # Make unique with UUID
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="admin",
),
)
test_keys.append(key)
# 2. Create 1 key for test user (no team)
key = await generate_key_fn(
data=GenerateKeyRequest(
user_id=test_user_id,
key_alias=test_key_alias, # Already unique from earlier UUID generation
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="admin",
),
)
test_keys.append(key)
# 3. Create 2 keys for other users
for i in range(2):
key = await generate_key_fn(
data=GenerateKeyRequest(
user_id=f"other_user_{i}",
key_alias=f"other_key_{uuid.uuid4()}", # Make unique with UUID
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="admin",
),
)
test_keys.append(key)
# Test 1: Basic pagination
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=2,
user_id=None,
team_id=None,
key_alias=None,
)
assert len(result["keys"]) == 2, "Should return exactly 2 keys"
assert result["total_count"] >= 5, "Should have at least 5 total keys"
assert result["current_page"] == 1
assert isinstance(result["keys"][0], str), "Should return token strings by default"
# Test 2: Filter by user_id
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=10,
user_id=test_user_id,
team_id=None,
key_alias=None,
)
assert len(result["keys"]) == 3, "Should return exactly 3 keys for test user"
# Test 3: Filter by team_id
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=10,
user_id=None,
team_id=test_team_id,
key_alias=None,
)
assert len(result["keys"]) == 2, "Should return exactly 2 keys for test team"
# Test 4: Filter by key_alias
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=10,
user_id=None,
team_id=None,
key_alias=test_key_alias,
)
assert len(result["keys"]) == 1, "Should return exactly 1 key with test alias"
# Test 5: Return full object
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=10,
user_id=test_user_id,
team_id=None,
key_alias=None,
return_full_object=True,
)
assert all(
isinstance(key, UserAPIKeyAuth) for key in result["keys"]
), "Should return UserAPIKeyAuth objects"
assert len(result["keys"]) == 3, "Should return exactly 3 keys for test user"
# Clean up test keys
for key in test_keys:
await delete_key_fn(
data=KeyRequest(keys=[key.key]),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="admin",
),
)