# Mirror of https://github.com/BerriAI/litellm.git
# Synced 2025-04-25 02:34:29 +00:00 (105 lines, 3.5 KiB, Python)
import json
|
|
import os
|
|
import sys
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from litellm.proxy._types import UserAPIKeyAuth
|
|
from litellm.proxy.litellm_pre_call_utils import (
|
|
_get_enforced_params,
|
|
check_if_token_is_service_account,
|
|
)
|
|
|
|
# Prepend the directory three levels above the current working directory
# (presumably the repository root when pytest is run from this test folder —
# TODO confirm) so that local sources take precedence on sys.path.
# NOTE(review): this executes after the litellm imports above, so it does not
# influence how those particular imports were resolved.
sys.path.insert(0, os.path.abspath(os.path.join("..", "..", "..")))
|
|
|
|
|
|
def test_check_if_token_is_service_account():
    """
    Test that only keys with `service_account_id` in metadata are considered
    service accounts.

    Identity checks (`is True` / `is False`) are used rather than `== True` /
    `== False`: they satisfy linters (E712) and also pin the helper to
    returning a real bool.
    """
    # Test case 1: Service account token
    service_account_token = UserAPIKeyAuth(
        api_key="test-key", metadata={"service_account_id": "test-service-account"}
    )
    assert check_if_token_is_service_account(service_account_token) is True

    # Test case 2: Regular user token with empty metadata
    regular_token = UserAPIKeyAuth(api_key="test-key", metadata={})
    assert check_if_token_is_service_account(regular_token) is False

    # Test case 3: Token with other metadata but no `service_account_id`
    other_metadata_token = UserAPIKeyAuth(
        api_key="test-key", metadata={"user_id": "test-user"}
    )
    assert check_if_token_is_service_account(other_metadata_token) is False

    # Test case 4: Token constructed without any metadata argument at all
    default_metadata_token = UserAPIKeyAuth(api_key="test-key")
    assert check_if_token_is_service_account(default_metadata_token) is False
|
|
|
|
|
|
def test_get_enforced_params_for_service_account_settings():
    """
    Verify that `service_account_settings.enforced_params` from the general
    settings is applied to service-account keys only, while a regular key
    receives just the `enforced_params` stored in its own metadata.
    """
    settings = {
        "service_account_settings": {"enforced_params": ["metadata.service"]},
    }

    # A key whose metadata carries `service_account_id` picks up the
    # service-account enforced params.
    sa_key = UserAPIKeyAuth(
        api_key="test-key", metadata={"service_account_id": "test-service-account"}
    )
    assert _get_enforced_params(
        general_settings=settings,
        user_api_key_dict=sa_key,
    ) == ["metadata.service"]

    # A key without `service_account_id` ignores the service-account settings
    # and only reports the params from its own metadata.
    plain_key = UserAPIKeyAuth(
        api_key="test-key", metadata={"enforced_params": ["user"]}
    )
    assert _get_enforced_params(
        general_settings=settings,
        user_api_key_dict=plain_key,
    ) == ["user"]
|
|
|
|
|
|
@pytest.mark.parametrize(
    "general_settings, user_api_key_dict, expected_enforced_params",
    [
        pytest.param(
            {"enforced_params": ["param1", "param2"]},
            UserAPIKeyAuth(
                api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
            ),
            ["param1", "param2"],
            id="global-enforced-params-regular-key",
        ),
        pytest.param(
            {"service_account_settings": {"enforced_params": ["param1", "param2"]}},
            UserAPIKeyAuth(
                api_key="test_api_key",
                user_id="test_user_id",
                org_id="test_org_id",
                metadata={"service_account_id": "test_service_account_id"},
            ),
            ["param1", "param2"],
            id="service-account-settings-only",
        ),
        pytest.param(
            {"service_account_settings": {"enforced_params": ["param1", "param2"]}},
            UserAPIKeyAuth(
                api_key="test_api_key",
                metadata={
                    "enforced_params": ["param3", "param4"],
                    "service_account_id": "test_service_account_id",
                },
            ),
            ["param1", "param2", "param3", "param4"],
            id="service-account-settings-plus-key-metadata",
        ),
    ],
)
def test_get_enforced_params(
    general_settings, user_api_key_dict, expected_enforced_params
):
    """
    Check `_get_enforced_params` across general-settings / key-metadata
    combinations; the last case expects service-account params first,
    followed by the key's own metadata params.
    """
    # `_get_enforced_params` is already imported at module level; keyword
    # arguments match the call style used by the other tests in this file.
    enforced_params = _get_enforced_params(
        general_settings=general_settings,
        user_api_key_dict=user_api_key_dict,
    )
    assert enforced_params == expected_enforced_params
|