diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index 84334b1db9..852a6f9933 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -790,21 +790,6 @@ async def _user_api_key_auth_builder(  # noqa: PLR0915
                 raise Exception(
                     "Key is blocked. Update via `/key/unblock` if you're admin."
                 )
-
-            # Check 1. If token can call model
-            _model_alias_map = {}
-            model: Optional[str] = None
-            if (
-                hasattr(valid_token, "team_model_aliases")
-                and valid_token.team_model_aliases is not None
-            ):
-                _model_alias_map = {
-                    **valid_token.aliases,
-                    **valid_token.team_model_aliases,
-                }
-            else:
-                _model_alias_map = {**valid_token.aliases}
-            litellm.model_alias_map = _model_alias_map
 
             config = valid_token.config
             if config != {}:
diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py
index 5892b7afc6..dab8b4a715 100644
--- a/litellm/proxy/litellm_pre_call_utils.py
+++ b/litellm/proxy/litellm_pre_call_utils.py
@@ -635,6 +635,12 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
         user_api_key_dict=user_api_key_dict,
     )
 
+    # Team Model Aliases
+    _update_model_if_team_alias_exists(
+        data=data,
+        user_api_key_dict=user_api_key_dict,
+    )
+
     verbose_proxy_logger.debug(
         "[PROXY] returned data from litellm_pre_call_utils: %s", data
     )
@@ -664,6 +670,32 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
     return data
 
 
+def _update_model_if_team_alias_exists(
+    data: dict,
+    user_api_key_dict: UserAPIKeyAuth,
+) -> None:
+    """
+    Update the model if a team alias exists for it.
+
+    If an alias map has been set on a team, we want to make the request with the model that the team alias points to.
+
+    e.g.
+    - user calls `gpt-4o`
+    - team.model_alias_map = {
+        "gpt-4o": "gpt-4o-team-1"
+    }
+    - requested_model = "gpt-4o-team-1"
+    """
+    _model = data.get("model")
+    if (
+        _model
+        and user_api_key_dict.team_model_aliases
+        and _model in user_api_key_dict.team_model_aliases
+    ):
+        data["model"] = user_api_key_dict.team_model_aliases[_model]
+    return
+
+
 def _get_enforced_params(
     general_settings: Optional[dict], user_api_key_dict: UserAPIKeyAuth
 ) -> Optional[list]:
diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py
index 36f9b6652f..6dc03916a7 100644
--- a/tests/proxy_unit_tests/test_proxy_utils.py
+++ b/tests/proxy_unit_tests/test_proxy_utils.py
@@ -1617,3 +1617,53 @@ def test_provider_specific_header():
             "anthropic-beta": "prompt-caching-2024-07-31",
         },
     }
+
+
+@pytest.mark.parametrize(
+    "data, user_api_key_dict, expected_model",
+    [
+        # Test case 1: Model exists in team aliases
+        (
+            {"model": "gpt-4o"},
+            UserAPIKeyAuth(
+                api_key="test_key", team_model_aliases={"gpt-4o": "gpt-4o-team-1"}
+            ),
+            "gpt-4o-team-1",
+        ),
+        # Test case 2: Model doesn't exist in team aliases
+        (
+            {"model": "gpt-4o"},
+            UserAPIKeyAuth(
+                api_key="test_key", team_model_aliases={"claude-3": "claude-3-team-1"}
+            ),
+            "gpt-4o",
+        ),
+        # Test case 3: No team aliases defined
+        (
+            {"model": "gpt-4o"},
+            UserAPIKeyAuth(api_key="test_key", team_model_aliases=None),
+            "gpt-4o",
+        ),
+        # Test case 4: No model in request data
+        (
+            {"messages": []},
+            UserAPIKeyAuth(
+                api_key="test_key", team_model_aliases={"gpt-4o": "gpt-4o-team-1"}
+            ),
+            None,
+        ),
+    ],
+)
+def test_update_model_if_team_alias_exists(data, user_api_key_dict, expected_model):
+    from litellm.proxy.litellm_pre_call_utils import _update_model_if_team_alias_exists
+
+    # Make a copy of the input data to avoid modifying the test parameters
+    test_data = data.copy()
+
+    # Call the function
+    _update_model_if_team_alias_exists(
+        data=test_data, user_api_key_dict=user_api_key_dict
+    )
+
+    # Check that the model was updated correctly
+    assert test_data.get("model") == expected_model
diff --git a/tests/store_model_in_db_tests/test_team_alias.py b/tests/store_model_in_db_tests/test_team_alias.py
new file mode 100644
index 0000000000..11c65dfdc5
--- /dev/null
+++ b/tests/store_model_in_db_tests/test_team_alias.py
@@ -0,0 +1,86 @@
+import pytest
+import asyncio
+import aiohttp
+import json
+from openai import AsyncOpenAI
+import uuid
+from httpx import AsyncClient
+import uuid
+import os
+
+TEST_MASTER_KEY = "sk-1234"
+PROXY_BASE_URL = "http://0.0.0.0:4000"
+
+
+@pytest.mark.asyncio
+async def test_team_model_alias():
+    """
+    Test model alias functionality with teams:
+    1. Add a new model with model_name="gpt-4o-team1" and litellm_params.model="gpt-4o"
+    2. Create a new team
+    3. Update the team with a model_aliases mapping
+    4. Generate a key for the team
+    5. Make a request with the aliased model name
+    """
+    client = AsyncClient(base_url=PROXY_BASE_URL)
+    headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}
+
+    # Add new model
+    model_response = await client.post(
+        "/model/new",
+        json={
+            "model_name": "gpt-4o-team1",
+            "litellm_params": {
+                "model": "gpt-4o",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+        },
+        headers=headers,
+    )
+    assert model_response.status_code == 200
+
+    # Create new team
+    team_response = await client.post(
+        "/team/new",
+        json={
+            "models": ["gpt-4o-team1"],
+        },
+        headers=headers,
+    )
+    assert team_response.status_code == 200
+    team_data = team_response.json()
+    team_id = team_data["team_id"]
+
+    # Update team with model alias
+    update_response = await client.post(
+        "/team/update",
+        json={"team_id": team_id, "model_aliases": {"gpt-4o": "gpt-4o-team1"}},
+        headers=headers,
+    )
+    assert update_response.status_code == 200
+
+    # Generate key for team
+    key_response = await client.post(
+        "/key/generate", json={"team_id": team_id}, headers=headers
+    )
+    assert key_response.status_code == 200
+    key = key_response.json()["key"]
+
+    # Make request with model alias
+    openai_client = AsyncOpenAI(api_key=key, base_url=f"{PROXY_BASE_URL}/v1")
+
+    response = await openai_client.chat.completions.create(
+        model="gpt-4o",
+        messages=[{"role": "user", "content": f"Test message {uuid.uuid4()}"}],
+    )
+
+    assert response is not None, "Should get valid response when using model alias"
+
+    # Cleanup - delete the model
+    model_id = model_response.json()["model_info"]["id"]
+    delete_response = await client.post(
+        "/model/delete",
+        json={"id": model_id},
+        headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
+    )
+    assert delete_response.status_code == 200
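
For reference, a minimal usage sketch of the helper added above; it is an illustration, not part of the patch. It assumes the patch is applied, and the import path litellm.proxy._types for UserAPIKeyAuth is an assumption not shown in the diff.

# Usage sketch (illustration only): resolving a team model alias with the new helper.
from litellm.proxy._types import UserAPIKeyAuth  # assumed import path, not shown in the diff
from litellm.proxy.litellm_pre_call_utils import _update_model_if_team_alias_exists

key_metadata = UserAPIKeyAuth(
    api_key="sk-placeholder",  # placeholder value for illustration
    team_model_aliases={"gpt-4o": "gpt-4o-team-1"},
)
request_data = {"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]}

# The helper mutates the request body in place, swapping in the alias target.
_update_model_if_team_alias_exists(data=request_data, user_api_key_dict=key_metadata)
print(request_data["model"])  # expected: "gpt-4o-team-1"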