diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 517cc7c73b..0590bcb50a 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -101,6 +101,7 @@ async def common_checks( team_object=team_object, model=_model, llm_router=llm_router, + team_model_aliases=valid_token.team_model_aliases if valid_token else None, ) ## 2.1 If user can call model (if personal key) @@ -968,6 +969,7 @@ async def _can_object_call_model( model: str, llm_router: Optional[Router], models: List[str], + team_model_aliases: Optional[Dict[str, str]] = None, ) -> Literal[True]: """ Checks if token can call a given model @@ -1002,6 +1004,9 @@ async def _can_object_call_model( verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}") + if _model_in_team_aliases(model=model, team_model_aliases=team_model_aliases): + return True + if _model_matches_any_wildcard_pattern_in_list( model=model, allowed_model_list=filtered_models ): @@ -1026,6 +1031,26 @@ async def _can_object_call_model( return True +def _model_in_team_aliases( + model: str, team_model_aliases: Optional[Dict[str, str]] = None +) -> bool: + """ + Returns True if `model` being accessed is an alias of a team model + + - `model=gpt-4o` + - `team_model_aliases={"gpt-4o": "gpt-4o-team-1"}` + - returns True + + - `model=gp-4o` + - `team_model_aliases={"o-3": "o3-preview"}` + - returns False + """ + if team_model_aliases: + if model in team_model_aliases: + return True + return False + + async def can_key_call_model( model: str, llm_model_list: Optional[list], @@ -1045,6 +1070,7 @@ async def can_key_call_model( model=model, llm_router=llm_router, models=valid_token.models, + team_model_aliases=valid_token.team_model_aliases, ) @@ -1217,6 +1243,7 @@ def _team_model_access_check( model: Optional[str], team_object: Optional[LiteLLM_TeamTable], llm_router: Optional[Router], + team_model_aliases: Optional[Dict[str, str]] = None, ): """ Access check for team models @@ -1244,6 +1271,8 @@ def _team_model_access_check( pass elif model and "*" in model: pass + elif _model_in_team_aliases(model=model, team_model_aliases=team_model_aliases): + pass elif _model_matches_any_wildcard_pattern_in_list( model=model, allowed_model_list=team_object.models ): diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index 35fbfe433e..d9b2ee8646 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -1516,7 +1516,11 @@ async def list_team( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - response = await prisma_client.db.litellm_teamtable.find_many() + response = await prisma_client.db.litellm_teamtable.find_many( + include={ + "litellm_model_table": True, + } + ) filtered_response = [] if user_id: diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py index ae80b05b70..4fb94a5462 100644 --- a/tests/proxy_admin_ui_tests/test_key_management.py +++ b/tests/proxy_admin_ui_tests/test_key_management.py @@ -1022,3 +1022,82 @@ async def test_key_generate_always_db_team(mock_get_team_object): mock_get_team_object.assert_called_once() assert mock_get_team_object.call_args.kwargs["check_db_only"] == True + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "requested_model, should_pass", + [ + ("gpt-4o", True), # Should pass - exact match in aliases + ("gpt-4o-team1", True), # Should pass - team has access to this deployment + ("gpt-4o-mini", False), # Should fail - not in aliases + ("o-3", False), # Should fail - not in aliases + ], +) +async def test_team_model_alias(prisma_client, requested_model, should_pass): + """ + Test team model alias functionality: + 1. Create team with model alias = `{gpt-4o: gpt-4o-team1}` + 2. Generate key for that team with model = `gpt-4o` + 3. Verify chat completion request works with aliased model = `gpt-4o` + """ + litellm.set_verbose = True + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + + # Create team with model alias + team_id = f"test_team_{uuid.uuid4()}" + await new_team( + data=NewTeamRequest( + team_id=team_id, + team_alias=f"test_team_alias_{uuid.uuid4()}", + models=["gpt-4o-team1"], + model_aliases={"gpt-4o": "gpt-4o-team1"}, + ), + http_request=Request(scope={"type": "http"}), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin" + ), + ) + + # Generate key for the team + new_key = await generate_key_fn( + data=GenerateKeyRequest( + team_id=team_id, + models=["gpt-4o-team1"], + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin" + ), + ) + + generated_key = new_key.key + + # Test chat completion request + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + + async def return_body(): + return_string = f'{{"model": "{requested_model}"}}' + return return_string.encode() + + request.body = return_body + + if should_pass: + # Verify the key works with the aliased model + result = await user_api_key_auth( + request=request, api_key=f"Bearer {generated_key}" + ) + + assert result.models == [ + "gpt-4o-team1" + ], "Expected model list to contain aliased model" + assert result.team_model_aliases == { + "gpt-4o": "gpt-4o-team1" + }, "Expected model aliases to be present" + else: + # Verify the key fails with non-aliased models + with pytest.raises(Exception) as exc_info: + await user_api_key_auth(request=request, api_key=f"Bearer {generated_key}") + assert exc_info.value.type == ProxyErrorTypes.key_model_access_denied diff --git a/tests/proxy_unit_tests/test_auth_checks.py b/tests/proxy_unit_tests/test_auth_checks.py index ad79328ade..0a8ebbe018 100644 --- a/tests/proxy_unit_tests/test_auth_checks.py +++ b/tests/proxy_unit_tests/test_auth_checks.py @@ -633,3 +633,52 @@ async def test_get_fuzzy_user_object(): mock_prisma.db.litellm_usertable.find_unique.assert_called_with( where={"sso_user_id": "sso_123"}, include={"organization_memberships": True} ) + + +@pytest.mark.parametrize( + "model, alias_map, expect_to_work", + [ + ("gpt-4", {"gpt-4": "gpt-4-team1"}, True), # model matches alias value + ("gpt-5", {"gpt-4": "gpt-4-team1"}, False), + ], +) +@pytest.mark.asyncio +async def test_can_key_call_model_with_aliases(model, alias_map, expect_to_work): + """ + Test if can_key_call_model correctly handles model aliases in the token + """ + from litellm.proxy.auth.auth_checks import can_key_call_model + + llm_model_list = [ + { + "model_name": "gpt-4-team1", + "litellm_params": { + "model": "gpt-4", + "api_key": "test-api-key", + }, + } + ] + router = litellm.Router(model_list=llm_model_list) + + user_api_key_object = UserAPIKeyAuth( + models=[ + "gpt-4-team1", + ], + team_model_aliases=alias_map, + ) + + if expect_to_work: + await can_key_call_model( + model=model, + llm_model_list=llm_model_list, + valid_token=user_api_key_object, + llm_router=router, + ) + else: + with pytest.raises(Exception) as e: + await can_key_call_model( + model=model, + llm_model_list=llm_model_list, + valid_token=user_api_key_object, + llm_router=router, + ) diff --git a/ui/litellm-dashboard/src/components/team/model_aliases_card.tsx b/ui/litellm-dashboard/src/components/team/model_aliases_card.tsx new file mode 100644 index 0000000000..ba1cdc209d --- /dev/null +++ b/ui/litellm-dashboard/src/components/team/model_aliases_card.tsx @@ -0,0 +1,174 @@ +import React, { useState } from "react"; +import { + Card, + Title, + Text, + Button as TremorButton, +} from "@tremor/react"; +import { Modal, Form, Select, Input, message } from "antd"; +import { teamUpdateCall } from "@/components/networking"; + +interface ModelAliasesCardProps { + teamId: string; + accessToken: string | null; + currentAliases: Record; + availableModels: string[]; + onUpdate: () => void; +} + +const ModelAliasesCard: React.FC = ({ + teamId, + accessToken, + currentAliases, + availableModels, + onUpdate, +}) => { + const [isModalVisible, setIsModalVisible] = useState(false); + const [form] = Form.useForm(); + + const handleCreateAlias = async (values: any) => { + try { + if (!accessToken) return; + + const newAliases = { + ...currentAliases, + [values.alias_name]: values.original_model, + }; + + const updateData = { + team_id: teamId, + model_aliases: newAliases, + }; + + await teamUpdateCall(accessToken, updateData); + message.success("Model alias created successfully"); + setIsModalVisible(false); + form.resetFields(); + currentAliases[values.alias_name] = values.original_model; + } catch (error) { + message.error("Failed to create model alias"); + console.error("Error creating model alias:", error); + } + }; + + return ( +
+ Team Aliases + + Allow a team to use an alias that points to a specific model deployment. + + + +
+
+
+
+
ALIAS
+
POINTS TO
+
+
+ setIsModalVisible(true)} + > + Create Model Alias + +
+ +
+ {Object.entries(currentAliases).map(([aliasName, originalModel], index) => ( +
+
+ + {aliasName} + +
+
+ + {originalModel} + +
+
+ ))} + {Object.keys(currentAliases).length === 0 && ( +
+ No model aliases configured +
+ )} +
+
+ + { + setIsModalVisible(false); + form.resetFields(); + }} + footer={null} + width={500} + > +
+ + + + + + + + +
+ { + setIsModalVisible(false); + form.resetFields(); + }} + className="bg-white text-gray-700 border border-gray-300 hover:bg-gray-50" + > + Cancel + + + Create Alias + +
+
+
+
+ ); +}; + +export default ModelAliasesCard; \ No newline at end of file diff --git a/ui/litellm-dashboard/src/components/team/team_info.tsx b/ui/litellm-dashboard/src/components/team/team_info.tsx index cbcdde89da..2ac3893635 100644 --- a/ui/litellm-dashboard/src/components/team/team_info.tsx +++ b/ui/litellm-dashboard/src/components/team/team_info.tsx @@ -29,6 +29,7 @@ import { PencilAltIcon, PlusIcon, TrashIcon } from "@heroicons/react/outline"; import TeamMemberModal from "./edit_membership"; import UserSearchModal from "@/components/common_components/user_search_modal"; import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_team_key"; +import ModelAliasesCard from "./model_aliases_card"; interface TeamData { @@ -586,7 +587,7 @@ const TeamInfoView: React.FC = ({ /> -
+
@@ -638,6 +639,15 @@ const TeamInfoView: React.FC = ({
)} + + {/* Add Model Aliases Card */} +