mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
Internal User Endpoint - vulnerability fix + response type fix (#8228)
* fix(key_management_endpoints.py): fix vulnerability where a user could update another user's keys Resolves https://github.com/BerriAI/litellm/issues/8031 * test(key_management_endpoints.py): return consistent 403 forbidden error when modifying key that doesn't belong to user * fix(internal_user_endpoints.py): return model max budget in internal user create response Fixes https://github.com/BerriAI/litellm/issues/7047 * test: fix test * test: update test to handle gemini token counter change * fix(factory.py): fix bedrock http:// handling * docs: fix typo in lm_studio.md (#8222) * test: fix testing * test: fix test --------- Co-authored-by: foreign-sub <51928805+foreign-sub@users.noreply.github.com>
This commit is contained in:
parent
f6bd48a1c5
commit
df93debbc7
7 changed files with 240 additions and 28 deletions
|
@ -787,6 +787,7 @@ class NewUserResponse(GenerateKeyResponse):
|
||||||
] = None
|
] = None
|
||||||
teams: Optional[list] = None
|
teams: Optional[list] = None
|
||||||
user_alias: Optional[str] = None
|
user_alias: Optional[str] = None
|
||||||
|
model_max_budget: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
class UpdateUserRequest(GenerateRequestBase):
|
class UpdateUserRequest(GenerateRequestBase):
|
||||||
|
|
|
@ -142,7 +142,6 @@ async def new_user(
|
||||||
data_json = data.json() # type: ignore
|
data_json = data.json() # type: ignore
|
||||||
data_json = _update_internal_new_user_params(data_json, data)
|
data_json = _update_internal_new_user_params(data_json, data)
|
||||||
response = await generate_key_helper_fn(request_type="user", **data_json)
|
response = await generate_key_helper_fn(request_type="user", **data_json)
|
||||||
|
|
||||||
# Admin UI Logic
|
# Admin UI Logic
|
||||||
# Add User to Team and Organization
|
# Add User to Team and Organization
|
||||||
# if team_id passed add this user to the team
|
# if team_id passed add this user to the team
|
||||||
|
@ -220,6 +219,7 @@ async def new_user(
|
||||||
tpm_limit=response.get("tpm_limit", None),
|
tpm_limit=response.get("tpm_limit", None),
|
||||||
rpm_limit=response.get("rpm_limit", None),
|
rpm_limit=response.get("rpm_limit", None),
|
||||||
budget_duration=response.get("budget_duration", None),
|
budget_duration=response.get("budget_duration", None),
|
||||||
|
model_max_budget=response.get("model_max_budget", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -65,7 +65,7 @@ def _get_user_in_team(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _is_allowed_to_create_key(
|
def _is_allowed_to_make_key_request(
|
||||||
user_api_key_dict: UserAPIKeyAuth, user_id: Optional[str], team_id: Optional[str]
|
user_api_key_dict: UserAPIKeyAuth, user_id: Optional[str], team_id: Optional[str]
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -266,6 +266,40 @@ def key_generation_check(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def common_key_access_checks(
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
data: Union[GenerateKeyRequest, UpdateKeyRequest],
|
||||||
|
llm_router: Optional[Router],
|
||||||
|
premium_user: bool,
|
||||||
|
) -> Literal[True]:
|
||||||
|
"""
|
||||||
|
Check if user is allowed to make a key request, for this key
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
_is_allowed_to_make_key_request(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
user_id=data.user_id,
|
||||||
|
team_id=data.team_id,
|
||||||
|
)
|
||||||
|
except AssertionError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail=str(e),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
_check_model_access_group(
|
||||||
|
models=data.models,
|
||||||
|
llm_router=llm_router,
|
||||||
|
premium_user=premium_user,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@ -381,25 +415,9 @@ async def generate_key_fn( # noqa: PLR0915
|
||||||
data=data,
|
data=data,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
common_key_access_checks(
|
||||||
_is_allowed_to_create_key(
|
user_api_key_dict=user_api_key_dict,
|
||||||
user_api_key_dict=user_api_key_dict,
|
data=data,
|
||||||
user_id=data.user_id,
|
|
||||||
team_id=data.team_id,
|
|
||||||
)
|
|
||||||
except AssertionError as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403,
|
|
||||||
detail=str(e),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail=str(e),
|
|
||||||
)
|
|
||||||
|
|
||||||
_check_model_access_group(
|
|
||||||
models=data.models,
|
|
||||||
llm_router=llm_router,
|
llm_router=llm_router,
|
||||||
premium_user=premium_user,
|
premium_user=premium_user,
|
||||||
)
|
)
|
||||||
|
@ -684,6 +702,8 @@ async def update_key_fn(
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
from litellm.proxy.proxy_server import (
|
from litellm.proxy.proxy_server import (
|
||||||
|
llm_router,
|
||||||
|
premium_user,
|
||||||
prisma_client,
|
prisma_client,
|
||||||
proxy_logging_obj,
|
proxy_logging_obj,
|
||||||
user_api_key_cache,
|
user_api_key_cache,
|
||||||
|
@ -692,10 +712,18 @@ async def update_key_fn(
|
||||||
try:
|
try:
|
||||||
data_json: dict = data.model_dump(exclude_unset=True, exclude_none=True)
|
data_json: dict = data.model_dump(exclude_unset=True, exclude_none=True)
|
||||||
key = data_json.pop("key")
|
key = data_json.pop("key")
|
||||||
|
|
||||||
# get the row from db
|
# get the row from db
|
||||||
if prisma_client is None:
|
if prisma_client is None:
|
||||||
raise Exception("Not connected to DB!")
|
raise Exception("Not connected to DB!")
|
||||||
|
|
||||||
|
common_key_access_checks(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
data=data,
|
||||||
|
llm_router=llm_router,
|
||||||
|
premium_user=premium_user,
|
||||||
|
)
|
||||||
|
|
||||||
existing_key_row = await prisma_client.get_data(
|
existing_key_row = await prisma_client.get_data(
|
||||||
token=data.key, table_name="key", query_type="find_unique"
|
token=data.key, table_name="key", query_type="find_unique"
|
||||||
)
|
)
|
||||||
|
@ -1412,6 +1440,13 @@ async def delete_verification_tokens(
|
||||||
):
|
):
|
||||||
await prisma_client.delete_data(tokens=[key.token])
|
await prisma_client.delete_data(tokens=[key.token])
|
||||||
deleted_tokens.append(key.token)
|
deleted_tokens.append(key.token)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail={
|
||||||
|
"error": "You are not authorized to delete this key"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
tasks.append(_delete_key(key))
|
tasks.append(_delete_key(key))
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
|
@ -842,7 +842,15 @@ async def test_key_update_with_model_specific_params(prisma_client):
|
||||||
"litellm_budget_table": None,
|
"litellm_budget_table": None,
|
||||||
"token": token_hash,
|
"token": token_hash,
|
||||||
}
|
}
|
||||||
await update_key_fn(request=request, data=UpdateKeyRequest(**args))
|
await update_key_fn(
|
||||||
|
request=request,
|
||||||
|
data=UpdateKeyRequest(**args),
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
|
@ -1314,6 +1314,11 @@ def test_generate_and_update_key(prisma_client):
|
||||||
budget_duration="1mo",
|
budget_duration="1mo",
|
||||||
max_budget=100,
|
max_budget=100,
|
||||||
),
|
),
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
print("response1=", response1)
|
print("response1=", response1)
|
||||||
|
@ -1322,6 +1327,11 @@ def test_generate_and_update_key(prisma_client):
|
||||||
response2 = await update_key_fn(
|
response2 = await update_key_fn(
|
||||||
request=Request,
|
request=Request,
|
||||||
data=UpdateKeyRequest(key=generated_key, team_id=_team_2),
|
data=UpdateKeyRequest(key=generated_key, team_id=_team_2),
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
print("response2=", response2)
|
print("response2=", response2)
|
||||||
|
|
||||||
|
@ -2956,7 +2966,11 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
|
||||||
_request = Request(scope={"type": "http"})
|
_request = Request(scope={"type": "http"})
|
||||||
_request._url = URL(url="/update/key")
|
_request._url = URL(url="/update/key")
|
||||||
|
|
||||||
await update_key_fn(data=request, request=_request)
|
await update_key_fn(
|
||||||
|
data=request,
|
||||||
|
request=_request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||||
|
)
|
||||||
result = await info_key_fn(
|
result = await info_key_fn(
|
||||||
key=generated_key,
|
key=generated_key,
|
||||||
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||||
|
@ -3017,7 +3031,11 @@ async def test_generate_key_with_guardrails(prisma_client):
|
||||||
_request = Request(scope={"type": "http"})
|
_request = Request(scope={"type": "http"})
|
||||||
_request._url = URL(url="/update/key")
|
_request._url = URL(url="/update/key")
|
||||||
|
|
||||||
await update_key_fn(data=request, request=_request)
|
await update_key_fn(
|
||||||
|
data=request,
|
||||||
|
request=_request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||||
|
)
|
||||||
result = await info_key_fn(
|
result = await info_key_fn(
|
||||||
key=generated_key,
|
key=generated_key,
|
||||||
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||||
|
@ -3710,6 +3728,11 @@ async def test_key_alias_uniqueness(prisma_client):
|
||||||
await update_key_fn(
|
await update_key_fn(
|
||||||
data=UpdateKeyRequest(key=key3.key, key_alias=unique_alias),
|
data=UpdateKeyRequest(key=key3.key, key_alias=unique_alias),
|
||||||
request=Request(scope={"type": "http"}),
|
request=Request(scope={"type": "http"}),
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
pytest.fail("Should not be able to update a key to use an existing alias")
|
pytest.fail("Should not be able to update a key to use an existing alias")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -3719,6 +3742,11 @@ async def test_key_alias_uniqueness(prisma_client):
|
||||||
updated_key = await update_key_fn(
|
updated_key = await update_key_fn(
|
||||||
data=UpdateKeyRequest(key=key1.key, key_alias=unique_alias),
|
data=UpdateKeyRequest(key=key1.key, key_alias=unique_alias),
|
||||||
request=Request(scope={"type": "http"}),
|
request=Request(scope={"type": "http"}),
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
assert updated_key is not None
|
assert updated_key is not None
|
||||||
|
|
||||||
|
|
|
@ -1216,14 +1216,14 @@ def test_litellm_verification_token_view_response_with_budget_table(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_is_allowed_to_create_key():
|
def test_is_allowed_to_make_key_request():
|
||||||
from litellm.proxy._types import LitellmUserRoles
|
from litellm.proxy._types import LitellmUserRoles
|
||||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||||
_is_allowed_to_create_key,
|
_is_allowed_to_make_key_request,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
_is_allowed_to_create_key(
|
_is_allowed_to_make_key_request(
|
||||||
user_api_key_dict=UserAPIKeyAuth(
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
user_id="test_user_id", user_role=LitellmUserRoles.PROXY_ADMIN
|
user_id="test_user_id", user_role=LitellmUserRoles.PROXY_ADMIN
|
||||||
),
|
),
|
||||||
|
@ -1234,7 +1234,7 @@ def test_is_allowed_to_create_key():
|
||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
_is_allowed_to_create_key(
|
_is_allowed_to_make_key_request(
|
||||||
user_api_key_dict=UserAPIKeyAuth(
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
user_id="test_user_id",
|
user_id="test_user_id",
|
||||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||||
|
@ -1553,6 +1553,7 @@ async def test_spend_logs_cleanup_after_error():
|
||||||
mock_client.spend_log_transactions == original_logs[100:]
|
mock_client.spend_log_transactions == original_logs[100:]
|
||||||
), "Should remove processed logs even after error"
|
), "Should remove processed logs even after error"
|
||||||
|
|
||||||
|
|
||||||
def test_provider_specific_header():
|
def test_provider_specific_header():
|
||||||
from litellm.proxy.litellm_pre_call_utils import (
|
from litellm.proxy.litellm_pre_call_utils import (
|
||||||
add_provider_specific_headers_to_request,
|
add_provider_specific_headers_to_request,
|
||||||
|
|
|
@ -315,3 +315,142 @@ async def test_user_model_access():
|
||||||
key=key,
|
key=key,
|
||||||
model="groq/claude-3-5-haiku-20241022",
|
model="groq/claude-3-5-haiku-20241022",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
import pytest
|
||||||
|
import aiohttp
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
async def setup_test_users(session: aiohttp.ClientSession) -> Tuple[Dict, Dict]:
|
||||||
|
"""
|
||||||
|
Create two test users and an additional key for the first user.
|
||||||
|
Returns tuple of (user1_data, user2_data) where each contains user info and keys.
|
||||||
|
"""
|
||||||
|
# Create two test users
|
||||||
|
user1 = await new_user(
|
||||||
|
session=session,
|
||||||
|
i=0,
|
||||||
|
budget=100,
|
||||||
|
budget_duration="30d",
|
||||||
|
models=["anthropic.claude-3-5-sonnet-20240620-v1:0"],
|
||||||
|
)
|
||||||
|
|
||||||
|
user2 = await new_user(
|
||||||
|
session=session,
|
||||||
|
i=1,
|
||||||
|
budget=100,
|
||||||
|
budget_duration="30d",
|
||||||
|
models=["anthropic.claude-3-5-sonnet-20240620-v1:0"],
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\nCreated two test users:")
|
||||||
|
print(f"User 1 ID: {user1['user_id']}")
|
||||||
|
print(f"User 2 ID: {user2['user_id']}")
|
||||||
|
|
||||||
|
# Create an additional key for user1
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {user1['key']}",
|
||||||
|
}
|
||||||
|
|
||||||
|
key_payload = {
|
||||||
|
"user_id": user1["user_id"],
|
||||||
|
"duration": "7d",
|
||||||
|
"key_alias": f"test_key_{uuid.uuid4()}",
|
||||||
|
"models": ["anthropic.claude-3-5-sonnet-20240620-v1:0"],
|
||||||
|
}
|
||||||
|
|
||||||
|
print("\nGenerating additional key for user1...")
|
||||||
|
key_response = await session.post(
|
||||||
|
f"http://0.0.0.0:4000/key/generate", headers=headers, json=key_payload
|
||||||
|
)
|
||||||
|
|
||||||
|
assert key_response.status == 200, "Failed to generate additional key for user1"
|
||||||
|
user1_additional_key = await key_response.json()
|
||||||
|
|
||||||
|
print(f"\nGenerated key details:")
|
||||||
|
print(json.dumps(user1_additional_key, indent=2))
|
||||||
|
|
||||||
|
# Return both users' data including the additional key
|
||||||
|
return {
|
||||||
|
"user_data": user1,
|
||||||
|
"additional_key": user1_additional_key,
|
||||||
|
"headers": headers,
|
||||||
|
}, {
|
||||||
|
"user_data": user2,
|
||||||
|
"headers": {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {user2['key']}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def print_response_details(response: aiohttp.ClientResponse) -> None:
|
||||||
|
"""Helper function to print response details"""
|
||||||
|
print("\nResponse Details:")
|
||||||
|
print(f"Status Code: {response.status}")
|
||||||
|
print("\nResponse Content:")
|
||||||
|
try:
|
||||||
|
formatted_json = json.dumps(await response.json(), indent=2)
|
||||||
|
print(formatted_json)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print(await response.text())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_key_update_user_isolation():
|
||||||
|
"""Test that a user cannot update a key that belongs to another user"""
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
user1_data, user2_data = await setup_test_users(session)
|
||||||
|
|
||||||
|
# Try to update the key to belong to user2
|
||||||
|
update_payload = {
|
||||||
|
"key": user1_data["additional_key"]["key"],
|
||||||
|
"user_id": user2_data["user_data"][
|
||||||
|
"user_id"
|
||||||
|
], # Attempting to change ownership
|
||||||
|
"metadata": {"purpose": "testing_user_isolation", "environment": "test"},
|
||||||
|
}
|
||||||
|
|
||||||
|
print("\nAttempting to update key ownership to user2...")
|
||||||
|
update_response = await session.post(
|
||||||
|
f"http://0.0.0.0:4000/key/update",
|
||||||
|
headers=user1_data["headers"], # Using user1's headers
|
||||||
|
json=update_payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
await print_response_details(update_response)
|
||||||
|
|
||||||
|
# Verify update attempt was rejected
|
||||||
|
assert (
|
||||||
|
update_response.status == 403
|
||||||
|
), "Request should have been rejected with 403 status code"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_key_delete_user_isolation():
|
||||||
|
"""Test that a user cannot delete a key that belongs to another user"""
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
user1_data, user2_data = await setup_test_users(session)
|
||||||
|
|
||||||
|
# Try to delete user1's additional key using user2's credentials
|
||||||
|
delete_payload = {
|
||||||
|
"keys": [user1_data["additional_key"]["key"]],
|
||||||
|
}
|
||||||
|
|
||||||
|
print("\nAttempting to delete user1's key using user2's credentials...")
|
||||||
|
delete_response = await session.post(
|
||||||
|
f"http://0.0.0.0:4000/key/delete",
|
||||||
|
headers=user2_data["headers"],
|
||||||
|
json=delete_payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
await print_response_details(delete_response)
|
||||||
|
|
||||||
|
# Verify delete attempt was rejected
|
||||||
|
assert (
|
||||||
|
delete_response.status == 403
|
||||||
|
), "Request should have been rejected with 403 status code"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue