Internal User Endpoint - vulnerability fix + response type fix (#8228)

* fix(key_management_endpoints.py): fix vulnerability where a user could update another user's keys

Resolves https://github.com/BerriAI/litellm/issues/8031

* test(key_management_endpoints.py): return consistent 403 forbidden error when modifying key that doesn't belong to user

* fix(internal_user_endpoints.py): return model max budget in internal user create response

Fixes https://github.com/BerriAI/litellm/issues/7047

* test: fix test

* test: update test to handle gemini token counter change

* fix(factory.py): fix bedrock http:// handling

* docs: fix typo in lm_studio.md (#8222)

* test: fix testing

* test: fix test

---------

Co-authored-by: foreign-sub <51928805+foreign-sub@users.noreply.github.com>
This commit is contained in:
Krish Dholakia 2025-02-04 06:41:14 -08:00 committed by GitHub
parent f6bd48a1c5
commit df93debbc7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 240 additions and 28 deletions

View file

@ -787,6 +787,7 @@ class NewUserResponse(GenerateKeyResponse):
] = None ] = None
teams: Optional[list] = None teams: Optional[list] = None
user_alias: Optional[str] = None user_alias: Optional[str] = None
model_max_budget: Optional[dict] = None
class UpdateUserRequest(GenerateRequestBase): class UpdateUserRequest(GenerateRequestBase):

View file

@ -142,7 +142,6 @@ async def new_user(
data_json = data.json() # type: ignore data_json = data.json() # type: ignore
data_json = _update_internal_new_user_params(data_json, data) data_json = _update_internal_new_user_params(data_json, data)
response = await generate_key_helper_fn(request_type="user", **data_json) response = await generate_key_helper_fn(request_type="user", **data_json)
# Admin UI Logic # Admin UI Logic
# Add User to Team and Organization # Add User to Team and Organization
# if team_id passed add this user to the team # if team_id passed add this user to the team
@ -220,6 +219,7 @@ async def new_user(
tpm_limit=response.get("tpm_limit", None), tpm_limit=response.get("tpm_limit", None),
rpm_limit=response.get("rpm_limit", None), rpm_limit=response.get("rpm_limit", None),
budget_duration=response.get("budget_duration", None), budget_duration=response.get("budget_duration", None),
model_max_budget=response.get("model_max_budget", None),
) )

View file

@ -65,7 +65,7 @@ def _get_user_in_team(
return None return None
def _is_allowed_to_create_key( def _is_allowed_to_make_key_request(
user_api_key_dict: UserAPIKeyAuth, user_id: Optional[str], team_id: Optional[str] user_api_key_dict: UserAPIKeyAuth, user_id: Optional[str], team_id: Optional[str]
) -> bool: ) -> bool:
""" """
@ -266,6 +266,40 @@ def key_generation_check(
) )
def common_key_access_checks(
user_api_key_dict: UserAPIKeyAuth,
data: Union[GenerateKeyRequest, UpdateKeyRequest],
llm_router: Optional[Router],
premium_user: bool,
) -> Literal[True]:
"""
Check if user is allowed to make a key request, for this key
"""
try:
_is_allowed_to_make_key_request(
user_api_key_dict=user_api_key_dict,
user_id=data.user_id,
team_id=data.team_id,
)
except AssertionError as e:
raise HTTPException(
status_code=403,
detail=str(e),
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=str(e),
)
_check_model_access_group(
models=data.models,
llm_router=llm_router,
premium_user=premium_user,
)
return True
router = APIRouter() router = APIRouter()
@ -381,25 +415,9 @@ async def generate_key_fn( # noqa: PLR0915
data=data, data=data,
) )
try: common_key_access_checks(
_is_allowed_to_create_key( user_api_key_dict=user_api_key_dict,
user_api_key_dict=user_api_key_dict, data=data,
user_id=data.user_id,
team_id=data.team_id,
)
except AssertionError as e:
raise HTTPException(
status_code=403,
detail=str(e),
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=str(e),
)
_check_model_access_group(
models=data.models,
llm_router=llm_router, llm_router=llm_router,
premium_user=premium_user, premium_user=premium_user,
) )
@ -684,6 +702,8 @@ async def update_key_fn(
``` ```
""" """
from litellm.proxy.proxy_server import ( from litellm.proxy.proxy_server import (
llm_router,
premium_user,
prisma_client, prisma_client,
proxy_logging_obj, proxy_logging_obj,
user_api_key_cache, user_api_key_cache,
@ -692,10 +712,18 @@ async def update_key_fn(
try: try:
data_json: dict = data.model_dump(exclude_unset=True, exclude_none=True) data_json: dict = data.model_dump(exclude_unset=True, exclude_none=True)
key = data_json.pop("key") key = data_json.pop("key")
# get the row from db # get the row from db
if prisma_client is None: if prisma_client is None:
raise Exception("Not connected to DB!") raise Exception("Not connected to DB!")
common_key_access_checks(
user_api_key_dict=user_api_key_dict,
data=data,
llm_router=llm_router,
premium_user=premium_user,
)
existing_key_row = await prisma_client.get_data( existing_key_row = await prisma_client.get_data(
token=data.key, table_name="key", query_type="find_unique" token=data.key, table_name="key", query_type="find_unique"
) )
@ -1412,6 +1440,13 @@ async def delete_verification_tokens(
): ):
await prisma_client.delete_data(tokens=[key.token]) await prisma_client.delete_data(tokens=[key.token])
deleted_tokens.append(key.token) deleted_tokens.append(key.token)
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"error": "You are not authorized to delete this key"
},
)
tasks.append(_delete_key(key)) tasks.append(_delete_key(key))
await asyncio.gather(*tasks) await asyncio.gather(*tasks)

View file

@ -842,7 +842,15 @@ async def test_key_update_with_model_specific_params(prisma_client):
"litellm_budget_table": None, "litellm_budget_table": None,
"token": token_hash, "token": token_hash,
} }
await update_key_fn(request=request, data=UpdateKeyRequest(**args)) await update_key_fn(
request=request,
data=UpdateKeyRequest(**args),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
@pytest.mark.asyncio @pytest.mark.asyncio

View file

@ -1314,6 +1314,11 @@ def test_generate_and_update_key(prisma_client):
budget_duration="1mo", budget_duration="1mo",
max_budget=100, max_budget=100,
), ),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
) )
print("response1=", response1) print("response1=", response1)
@ -1322,6 +1327,11 @@ def test_generate_and_update_key(prisma_client):
response2 = await update_key_fn( response2 = await update_key_fn(
request=Request, request=Request,
data=UpdateKeyRequest(key=generated_key, team_id=_team_2), data=UpdateKeyRequest(key=generated_key, team_id=_team_2),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
) )
print("response2=", response2) print("response2=", response2)
@ -2956,7 +2966,11 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
_request = Request(scope={"type": "http"}) _request = Request(scope={"type": "http"})
_request._url = URL(url="/update/key") _request._url = URL(url="/update/key")
await update_key_fn(data=request, request=_request) await update_key_fn(
data=request,
request=_request,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
)
result = await info_key_fn( result = await info_key_fn(
key=generated_key, key=generated_key,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
@ -3017,7 +3031,11 @@ async def test_generate_key_with_guardrails(prisma_client):
_request = Request(scope={"type": "http"}) _request = Request(scope={"type": "http"})
_request._url = URL(url="/update/key") _request._url = URL(url="/update/key")
await update_key_fn(data=request, request=_request) await update_key_fn(
data=request,
request=_request,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
)
result = await info_key_fn( result = await info_key_fn(
key=generated_key, key=generated_key,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
@ -3710,6 +3728,11 @@ async def test_key_alias_uniqueness(prisma_client):
await update_key_fn( await update_key_fn(
data=UpdateKeyRequest(key=key3.key, key_alias=unique_alias), data=UpdateKeyRequest(key=key3.key, key_alias=unique_alias),
request=Request(scope={"type": "http"}), request=Request(scope={"type": "http"}),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
) )
pytest.fail("Should not be able to update a key to use an existing alias") pytest.fail("Should not be able to update a key to use an existing alias")
except Exception as e: except Exception as e:
@ -3719,6 +3742,11 @@ async def test_key_alias_uniqueness(prisma_client):
updated_key = await update_key_fn( updated_key = await update_key_fn(
data=UpdateKeyRequest(key=key1.key, key_alias=unique_alias), data=UpdateKeyRequest(key=key1.key, key_alias=unique_alias),
request=Request(scope={"type": "http"}), request=Request(scope={"type": "http"}),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
) )
assert updated_key is not None assert updated_key is not None

View file

@ -1216,14 +1216,14 @@ def test_litellm_verification_token_view_response_with_budget_table(
) )
def test_is_allowed_to_create_key(): def test_is_allowed_to_make_key_request():
from litellm.proxy._types import LitellmUserRoles from litellm.proxy._types import LitellmUserRoles
from litellm.proxy.management_endpoints.key_management_endpoints import ( from litellm.proxy.management_endpoints.key_management_endpoints import (
_is_allowed_to_create_key, _is_allowed_to_make_key_request,
) )
assert ( assert (
_is_allowed_to_create_key( _is_allowed_to_make_key_request(
user_api_key_dict=UserAPIKeyAuth( user_api_key_dict=UserAPIKeyAuth(
user_id="test_user_id", user_role=LitellmUserRoles.PROXY_ADMIN user_id="test_user_id", user_role=LitellmUserRoles.PROXY_ADMIN
), ),
@ -1234,7 +1234,7 @@ def test_is_allowed_to_create_key():
) )
assert ( assert (
_is_allowed_to_create_key( _is_allowed_to_make_key_request(
user_api_key_dict=UserAPIKeyAuth( user_api_key_dict=UserAPIKeyAuth(
user_id="test_user_id", user_id="test_user_id",
user_role=LitellmUserRoles.INTERNAL_USER, user_role=LitellmUserRoles.INTERNAL_USER,
@ -1553,6 +1553,7 @@ async def test_spend_logs_cleanup_after_error():
mock_client.spend_log_transactions == original_logs[100:] mock_client.spend_log_transactions == original_logs[100:]
), "Should remove processed logs even after error" ), "Should remove processed logs even after error"
def test_provider_specific_header(): def test_provider_specific_header():
from litellm.proxy.litellm_pre_call_utils import ( from litellm.proxy.litellm_pre_call_utils import (
add_provider_specific_headers_to_request, add_provider_specific_headers_to_request,

View file

@ -315,3 +315,142 @@ async def test_user_model_access():
key=key, key=key,
model="groq/claude-3-5-haiku-20241022", model="groq/claude-3-5-haiku-20241022",
) )
import json
import uuid
import pytest
import aiohttp
from typing import Dict, Tuple
async def setup_test_users(session: aiohttp.ClientSession) -> Tuple[Dict, Dict]:
"""
Create two test users and an additional key for the first user.
Returns tuple of (user1_data, user2_data) where each contains user info and keys.
"""
# Create two test users
user1 = await new_user(
session=session,
i=0,
budget=100,
budget_duration="30d",
models=["anthropic.claude-3-5-sonnet-20240620-v1:0"],
)
user2 = await new_user(
session=session,
i=1,
budget=100,
budget_duration="30d",
models=["anthropic.claude-3-5-sonnet-20240620-v1:0"],
)
print("\nCreated two test users:")
print(f"User 1 ID: {user1['user_id']}")
print(f"User 2 ID: {user2['user_id']}")
# Create an additional key for user1
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {user1['key']}",
}
key_payload = {
"user_id": user1["user_id"],
"duration": "7d",
"key_alias": f"test_key_{uuid.uuid4()}",
"models": ["anthropic.claude-3-5-sonnet-20240620-v1:0"],
}
print("\nGenerating additional key for user1...")
key_response = await session.post(
f"http://0.0.0.0:4000/key/generate", headers=headers, json=key_payload
)
assert key_response.status == 200, "Failed to generate additional key for user1"
user1_additional_key = await key_response.json()
print(f"\nGenerated key details:")
print(json.dumps(user1_additional_key, indent=2))
# Return both users' data including the additional key
return {
"user_data": user1,
"additional_key": user1_additional_key,
"headers": headers,
}, {
"user_data": user2,
"headers": {
"Content-Type": "application/json",
"Authorization": f"Bearer {user2['key']}",
},
}
async def print_response_details(response: aiohttp.ClientResponse) -> None:
"""Helper function to print response details"""
print("\nResponse Details:")
print(f"Status Code: {response.status}")
print("\nResponse Content:")
try:
formatted_json = json.dumps(await response.json(), indent=2)
print(formatted_json)
except json.JSONDecodeError:
print(await response.text())
@pytest.mark.asyncio
async def test_key_update_user_isolation():
"""Test that a user cannot update a key that belongs to another user"""
async with aiohttp.ClientSession() as session:
user1_data, user2_data = await setup_test_users(session)
# Try to update the key to belong to user2
update_payload = {
"key": user1_data["additional_key"]["key"],
"user_id": user2_data["user_data"][
"user_id"
], # Attempting to change ownership
"metadata": {"purpose": "testing_user_isolation", "environment": "test"},
}
print("\nAttempting to update key ownership to user2...")
update_response = await session.post(
f"http://0.0.0.0:4000/key/update",
headers=user1_data["headers"], # Using user1's headers
json=update_payload,
)
await print_response_details(update_response)
# Verify update attempt was rejected
assert (
update_response.status == 403
), "Request should have been rejected with 403 status code"
@pytest.mark.asyncio
async def test_key_delete_user_isolation():
"""Test that a user cannot delete a key that belongs to another user"""
async with aiohttp.ClientSession() as session:
user1_data, user2_data = await setup_test_users(session)
# Try to delete user1's additional key using user2's credentials
delete_payload = {
"keys": [user1_data["additional_key"]["key"]],
}
print("\nAttempting to delete user1's key using user2's credentials...")
delete_response = await session.post(
f"http://0.0.0.0:4000/key/delete",
headers=user2_data["headers"],
json=delete_payload,
)
await print_response_details(delete_response)
# Verify delete attempt was rejected
assert (
delete_response.status == 403
), "Request should have been rejected with 403 status code"