From 0c3b7bb37ddabce3272216b2d810dd0f5356d9f7 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Mon, 21 Apr 2025 16:17:45 -0700 Subject: [PATCH] =?UTF-8?q?fix(router.py):=20handle=20edge=20case=20where?= =?UTF-8?q?=20user=20sets=20'model=5Fgroup'=20inside=E2=80=A6=20(#10191)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(router.py): handle edge case where user sets 'model_group' inside 'model_info' * fix(key_management_endpoints.py): security fix - return hashed token in 'token' field Ensures when creating a key on UI - only hashed token shown * test(test_key_management_endpoints.py): add unit test * test: update test --- .../key_management_endpoints.py | 42 ++++++++------- litellm/router.py | 8 ++- .../test_key_management_endpoints.py | 51 +++++++++++++++++++ tests/litellm/test_router.py | 26 ++++++++++ .../test_adding_passthrough_model.py | 2 +- 5 files changed, 107 insertions(+), 22 deletions(-) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index ad9813aaf8..8fd3b555d4 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -577,12 +577,16 @@ async def generate_key_fn( # noqa: PLR0915 request_type="key", **data_json, table_name="key" ) - response["soft_budget"] = ( - data.soft_budget - ) # include the user-input soft budget in the response + response[ + "soft_budget" + ] = data.soft_budget # include the user-input soft budget in the response response = GenerateKeyResponse(**response) + response.token = ( + response.token_id + ) # remap token to use the hash, and leave the key in the `key` field [TODO]: clean up generate_key_helper_fn to do this + asyncio.create_task( KeyManagementEventHooks.async_key_generated_hook( data=data, @@ -1470,10 +1474,10 @@ async def delete_verification_tokens( try: if prisma_client: tokens = [_hash_token_if_needed(token=key) for key in tokens] - _keys_being_deleted: List[LiteLLM_VerificationToken] = ( - await prisma_client.db.litellm_verificationtoken.find_many( - where={"token": {"in": tokens}} - ) + _keys_being_deleted: List[ + LiteLLM_VerificationToken + ] = await prisma_client.db.litellm_verificationtoken.find_many( + where={"token": {"in": tokens}} ) # Assuming 'db' is your Prisma Client instance @@ -1575,9 +1579,9 @@ async def _rotate_master_key( from litellm.proxy.proxy_server import proxy_config try: - models: Optional[List] = ( - await prisma_client.db.litellm_proxymodeltable.find_many() - ) + models: Optional[ + List + ] = await prisma_client.db.litellm_proxymodeltable.find_many() except Exception: models = None # 2. process model table @@ -1864,11 +1868,11 @@ async def validate_key_list_check( param="user_id", code=status.HTTP_403_FORBIDDEN, ) - complete_user_info_db_obj: Optional[BaseModel] = ( - await prisma_client.db.litellm_usertable.find_unique( - where={"user_id": user_api_key_dict.user_id}, - include={"organization_memberships": True}, - ) + complete_user_info_db_obj: Optional[ + BaseModel + ] = await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": user_api_key_dict.user_id}, + include={"organization_memberships": True}, ) if complete_user_info_db_obj is None: @@ -1929,10 +1933,10 @@ async def get_admin_team_ids( if complete_user_info is None: return [] # Get all teams that user is an admin of - teams: Optional[List[BaseModel]] = ( - await prisma_client.db.litellm_teamtable.find_many( - where={"team_id": {"in": complete_user_info.teams}} - ) + teams: Optional[ + List[BaseModel] + ] = await prisma_client.db.litellm_teamtable.find_many( + where={"team_id": {"in": complete_user_info.teams}} ) if teams is None: return [] diff --git a/litellm/router.py b/litellm/router.py index bf1bca0a75..6deffa9761 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -4983,8 +4983,12 @@ class Router: ) if model_group_info is None: - model_group_info = ModelGroupInfo( - model_group=user_facing_model_group_name, providers=[llm_provider], **model_info # type: ignore + model_group_info = ModelGroupInfo( # type: ignore + **{ + "model_group": user_facing_model_group_name, + "providers": [llm_provider], + **model_info, + } ) else: # if max_input_tokens > curr diff --git a/tests/litellm/proxy/management_endpoints/test_key_management_endpoints.py b/tests/litellm/proxy/management_endpoints/test_key_management_endpoints.py index 51bbbb49c4..c436e08901 100644 --- a/tests/litellm/proxy/management_endpoints/test_key_management_endpoints.py +++ b/tests/litellm/proxy/management_endpoints/test_key_management_endpoints.py @@ -46,3 +46,54 @@ async def test_list_keys(): assert json.dumps({"team_id": {"not": "litellm-dashboard"}}) in json.dumps( where_condition ) + + +@pytest.mark.asyncio +async def test_key_token_handling(monkeypatch): + """ + Test that token handling in key generation follows the expected behavior: + 1. token field should not equal key field + 2. if token_id exists, it should equal token field + """ + mock_prisma_client = AsyncMock() + mock_insert_data = AsyncMock( + return_value=MagicMock(token="hashed_token_123", litellm_budget_table=None) + ) + mock_prisma_client.insert_data = mock_insert_data + mock_prisma_client.db = MagicMock() + mock_prisma_client.db.litellm_verificationtoken = MagicMock() + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=None + ) + mock_prisma_client.db.litellm_verificationtoken.find_many = AsyncMock( + return_value=[] + ) + mock_prisma_client.db.litellm_verificationtoken.count = AsyncMock(return_value=0) + mock_prisma_client.db.litellm_verificationtoken.update = AsyncMock( + return_value=MagicMock(token="hashed_token_123", litellm_budget_table=None) + ) + + from litellm.proxy._types import GenerateKeyRequest, LitellmUserRoles + from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth + from litellm.proxy.management_endpoints.key_management_endpoints import ( + generate_key_fn, + ) + from litellm.proxy.proxy_server import prisma_client + + # Use monkeypatch to set the prisma_client + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + + # Test key generation + response = await generate_key_fn( + data=GenerateKeyRequest(), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234" + ), + ) + + # Verify token handling + assert response.key != response.token, "Token should not equal key" + if hasattr(response, "token_id"): + assert ( + response.token == response.token_id + ), "Token should equal token_id if token_id exists" diff --git a/tests/litellm/test_router.py b/tests/litellm/test_router.py index 66a420e79d..3a572d861d 100644 --- a/tests/litellm/test_router.py +++ b/tests/litellm/test_router.py @@ -52,3 +52,29 @@ def test_update_kwargs_does_not_mutate_defaults_and_merges_metadata(): # 3) metadata lands under "metadata" assert kwargs["litellm_metadata"] == {"baz": 123} + + +def test_router_with_model_info_and_model_group(): + """ + Test edge case where user specifies model_group in model_info + """ + router = litellm.Router( + model_list=[ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + }, + "model_info": { + "tpm": 1000, + "rpm": 1000, + "model_group": "gpt-3.5-turbo", + }, + } + ], + ) + + router._set_model_group_info( + model_group="gpt-3.5-turbo", + user_facing_model_group_name="gpt-3.5-turbo", + ) diff --git a/tests/store_model_in_db_tests/test_adding_passthrough_model.py b/tests/store_model_in_db_tests/test_adding_passthrough_model.py index ad26e19bd6..e901be5bd7 100644 --- a/tests/store_model_in_db_tests/test_adding_passthrough_model.py +++ b/tests/store_model_in_db_tests/test_adding_passthrough_model.py @@ -142,7 +142,7 @@ def create_virtual_key(): json={}, ) print(response.json()) - return response.json()["token"] + return response.json()["key"] def add_assembly_ai_model_to_db(