fix(router.py): handle edge case where user sets 'model_group' inside… (#10191)

* fix(router.py): handle edge case where user sets 'model_group' inside 'model_info'

* fix(key_management_endpoints.py): security fix - return hashed token in 'token' field

Ensures that when creating a key on the UI, only the hashed token is shown

* test(test_key_management_endpoints.py): add unit test

* test: update test
This commit is contained in:
Krish Dholakia 2025-04-21 16:17:45 -07:00 committed by GitHub
parent 28467252c0
commit 539ca4f620
5 changed files with 107 additions and 22 deletions

View file

@ -577,12 +577,16 @@ async def generate_key_fn( # noqa: PLR0915
request_type="key", **data_json, table_name="key"
)
response["soft_budget"] = (
data.soft_budget
) # include the user-input soft budget in the response
response[
"soft_budget"
] = data.soft_budget # include the user-input soft budget in the response
response = GenerateKeyResponse(**response)
response.token = (
response.token_id
) # remap token to use the hash, and leave the key in the `key` field [TODO]: clean up generate_key_helper_fn to do this
asyncio.create_task(
KeyManagementEventHooks.async_key_generated_hook(
data=data,
@ -1470,10 +1474,10 @@ async def delete_verification_tokens(
try:
if prisma_client:
tokens = [_hash_token_if_needed(token=key) for key in tokens]
_keys_being_deleted: List[LiteLLM_VerificationToken] = (
await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}}
)
_keys_being_deleted: List[
LiteLLM_VerificationToken
] = await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}}
)
# Assuming 'db' is your Prisma Client instance
@ -1575,9 +1579,9 @@ async def _rotate_master_key(
from litellm.proxy.proxy_server import proxy_config
try:
models: Optional[List] = (
await prisma_client.db.litellm_proxymodeltable.find_many()
)
models: Optional[
List
] = await prisma_client.db.litellm_proxymodeltable.find_many()
except Exception:
models = None
# 2. process model table
@ -1864,11 +1868,11 @@ async def validate_key_list_check(
param="user_id",
code=status.HTTP_403_FORBIDDEN,
)
complete_user_info_db_obj: Optional[BaseModel] = (
await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
)
complete_user_info_db_obj: Optional[
BaseModel
] = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
)
if complete_user_info_db_obj is None:
@ -1929,10 +1933,10 @@ async def get_admin_team_ids(
if complete_user_info is None:
return []
# Get all teams that user is an admin of
teams: Optional[List[BaseModel]] = (
await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": complete_user_info.teams}}
)
teams: Optional[
List[BaseModel]
] = await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": complete_user_info.teams}}
)
if teams is None:
return []

View file

@ -4983,8 +4983,12 @@ class Router:
)
if model_group_info is None:
model_group_info = ModelGroupInfo(
model_group=user_facing_model_group_name, providers=[llm_provider], **model_info # type: ignore
model_group_info = ModelGroupInfo( # type: ignore
**{
"model_group": user_facing_model_group_name,
"providers": [llm_provider],
**model_info,
}
)
else:
# if max_input_tokens > curr

View file

@ -46,3 +46,54 @@ async def test_list_keys():
assert json.dumps({"team_id": {"not": "litellm-dashboard"}}) in json.dumps(
where_condition
)
@pytest.mark.asyncio
async def test_key_token_handling(monkeypatch):
    """
    Test that token handling in key generation follows the expected behavior:
    1. token field should not equal key field
    2. if token_id exists, it should equal token field
    """
    # Stub out the Prisma client so the endpoint touches no real database.
    mock_prisma_client = AsyncMock()
    mock_insert_data = AsyncMock(
        return_value=MagicMock(token="hashed_token_123", litellm_budget_table=None)
    )
    mock_prisma_client.insert_data = mock_insert_data
    mock_prisma_client.db = MagicMock()
    mock_prisma_client.db.litellm_verificationtoken = MagicMock()
    mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock(
        return_value=None
    )
    mock_prisma_client.db.litellm_verificationtoken.find_many = AsyncMock(
        return_value=[]
    )
    mock_prisma_client.db.litellm_verificationtoken.count = AsyncMock(return_value=0)
    mock_prisma_client.db.litellm_verificationtoken.update = AsyncMock(
        return_value=MagicMock(token="hashed_token_123", litellm_budget_table=None)
    )

    # Imports are deferred into the test body to avoid module-level import
    # side effects in the proxy package.
    # NOTE: the original also did `from litellm.proxy.proxy_server import
    # prisma_client`, but that local binding was never used — monkeypatch
    # below patches the module attribute directly — so it is removed.
    from litellm.proxy._types import GenerateKeyRequest, LitellmUserRoles
    from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth
    from litellm.proxy.management_endpoints.key_management_endpoints import (
        generate_key_fn,
    )

    # Use monkeypatch to set the prisma_client
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)

    # Test key generation
    response = await generate_key_fn(
        data=GenerateKeyRequest(),
        user_api_key_dict=UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234"
        ),
    )

    # Verify token handling: the raw key must never be echoed back in the
    # `token` field (security fix this test covers), and `token` must be the
    # hashed id when one exists.
    assert response.key != response.token, "Token should not equal key"
    if hasattr(response, "token_id"):
        assert (
            response.token == response.token_id
        ), "Token should equal token_id if token_id exists"

View file

@ -52,3 +52,29 @@ def test_update_kwargs_does_not_mutate_defaults_and_merges_metadata():
# 3) metadata lands under "metadata"
assert kwargs["litellm_metadata"] == {"baz": 123}
def test_router_with_model_info_and_model_group():
    """
    Test edge case where user specifies model_group in model_info

    The deployment's model_info carries a 'model_group' key that collides with
    the keyword _set_model_group_info passes explicitly; constructing the
    router and calling the method must not raise.
    """
    deployment = {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo",
        },
        "model_info": {
            "tpm": 1000,
            "rpm": 1000,
            "model_group": "gpt-3.5-turbo",
        },
    }
    router = litellm.Router(model_list=[deployment])
    # Smoke check: succeeds without a duplicate-kwarg TypeError.
    router._set_model_group_info(
        model_group="gpt-3.5-turbo",
        user_facing_model_group_name="gpt-3.5-turbo",
    )

View file

@ -142,7 +142,7 @@ def create_virtual_key():
json={},
)
print(response.json())
return response.json()["token"]
return response.json()["key"]
def add_assembly_ai_model_to_db(